Skip to content
Snippets Groups Projects
Commit 311b2010 authored by Franck Dary's avatar Franck Dary
Browse files

Removed dimension reduction for NumericColumn module

parent e4ea7fb9
No related branches found
No related tags found
No related merge requests found
......@@ -2,16 +2,21 @@
ConcatImpl::ConcatImpl(int inputSize, int outputSize) : inputSize(inputSize), outputSize(outputSize)
{
dimReduce = register_module("dimReduce", torch::nn::Linear(inputSize, outputSize));
if (inputSize and outputSize) // if one of these is null, don't use a linear layer
dimReduce = register_module("dimReduce", torch::nn::Linear(inputSize, outputSize));
}
torch::Tensor ConcatImpl::forward(torch::Tensor input)
{
return dimReduce(input).view({input.size(0), -1});
if (dimReduce)
return dimReduce(input).view({input.size(0), -1});
return input.view({input.size(0), -1});
}
int ConcatImpl::getOutputSize(int sequenceLength)
{
return sequenceLength * outputSize;
if (outputSize)
return sequenceLength * outputSize;
return sequenceLength;
}
......@@ -35,7 +35,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(1, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(1,1));
myModule = register_module("myModule", Concat(0,0));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment