From 311b201026800be4e8882088b001d320e53ccc3d Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 13 Jul 2022 17:25:32 +0200
Subject: [PATCH] Removed dimension reduction for NumericColumn module

---
 torch_modules/src/Concat.cpp              | 11 ++++++++---
 torch_modules/src/NumericColumnModule.cpp |  2 +-
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/torch_modules/src/Concat.cpp b/torch_modules/src/Concat.cpp
index 7ba200c..381b96a 100644
--- a/torch_modules/src/Concat.cpp
+++ b/torch_modules/src/Concat.cpp
@@ -2,16 +2,21 @@
 
 ConcatImpl::ConcatImpl(int inputSize, int outputSize) : inputSize(inputSize), outputSize(outputSize)
 {
-  dimReduce = register_module("dimReduce", torch::nn::Linear(inputSize, outputSize));
+  if (inputSize and outputSize) // if one of these is null, don't use a linear layer
+    dimReduce = register_module("dimReduce", torch::nn::Linear(inputSize, outputSize));
 }
 
 torch::Tensor ConcatImpl::forward(torch::Tensor input)
 {
-  return dimReduce(input).view({input.size(0), -1});
+  if (dimReduce)
+    return dimReduce(input).view({input.size(0), -1});
+  return input.view({input.size(0), -1});
 }
 
 int ConcatImpl::getOutputSize(int sequenceLength)
 {
-  return sequenceLength * outputSize;
+  if (outputSize)
+    return sequenceLength * outputSize;
+  return sequenceLength;
 }
 
diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp
index 5a1191d..0abf4ac 100644
--- a/torch_modules/src/NumericColumnModule.cpp
+++ b/torch_modules/src/NumericColumnModule.cpp
@@ -35,7 +35,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
             else if (subModuleType == "GRU")
               myModule = register_module("myModule", GRU(1, outSize, options));
             else if (subModuleType == "Concat")
-              myModule = register_module("myModule", Concat(1,1));
+              myModule = register_module("myModule", Concat(0,0));
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
           } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
-- 
GitLab