From 311b201026800be4e8882088b001d320e53ccc3d Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 13 Jul 2022 17:25:32 +0200 Subject: [PATCH] Removed dimension reduction for NumericColumn module --- torch_modules/src/Concat.cpp | 11 ++++++++--- torch_modules/src/NumericColumnModule.cpp | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/torch_modules/src/Concat.cpp b/torch_modules/src/Concat.cpp index 7ba200c..381b96a 100644 --- a/torch_modules/src/Concat.cpp +++ b/torch_modules/src/Concat.cpp @@ -2,16 +2,21 @@ ConcatImpl::ConcatImpl(int inputSize, int outputSize) : inputSize(inputSize), outputSize(outputSize) { - dimReduce = register_module("dimReduce", torch::nn::Linear(inputSize, outputSize)); + if (inputSize and outputSize) // if one of these is null, don't use a linear layer + dimReduce = register_module("dimReduce", torch::nn::Linear(inputSize, outputSize)); } torch::Tensor ConcatImpl::forward(torch::Tensor input) { - return dimReduce(input).view({input.size(0), -1}); + if (dimReduce) + return dimReduce(input).view({input.size(0), -1}); + return input.view({input.size(0), -1}); } int ConcatImpl::getOutputSize(int sequenceLength) { - return sequenceLength * outputSize; + if (outputSize) + return sequenceLength * outputSize; + return sequenceLength; } diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp index 5a1191d..0abf4ac 100644 --- a/torch_modules/src/NumericColumnModule.cpp +++ b/torch_modules/src/NumericColumnModule.cpp @@ -35,7 +35,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(1, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(1,1)); + myModule = register_module("myModule", Concat(0,0)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} -- GitLab