diff --git a/torch_modules/src/Concat.cpp b/torch_modules/src/Concat.cpp index 7ba200c3c5e322d35aeb8fb4870d95fdb0af0bd3..381b96a974c67607e07fcbc413ace81470a2fbc1 100644 --- a/torch_modules/src/Concat.cpp +++ b/torch_modules/src/Concat.cpp @@ -2,16 +2,21 @@ ConcatImpl::ConcatImpl(int inputSize, int outputSize) : inputSize(inputSize), outputSize(outputSize) { - dimReduce = register_module("dimReduce", torch::nn::Linear(inputSize, outputSize)); + if (inputSize and outputSize) // if one of these is null, don't use a linear layer + dimReduce = register_module("dimReduce", torch::nn::Linear(inputSize, outputSize)); } torch::Tensor ConcatImpl::forward(torch::Tensor input) { - return dimReduce(input).view({input.size(0), -1}); + if (dimReduce) + return dimReduce(input).view({input.size(0), -1}); + return input.view({input.size(0), -1}); } int ConcatImpl::getOutputSize(int sequenceLength) { - return sequenceLength * outputSize; + if (outputSize) + return sequenceLength * outputSize; + return sequenceLength; } diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp index 5a1191d9887be93d221fc0b842e53b4e90c03ccc..0abf4acdcb4512c156e749bb5a45c168e4824285 100644 --- a/torch_modules/src/NumericColumnModule.cpp +++ b/torch_modules/src/NumericColumnModule.cpp @@ -35,7 +35,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(1, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(1,1)); + myModule = register_module("myModule", Concat(0,0)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}