From 0652f5872625aadab909bc40d417ee7cf9195fe7 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 10 Oct 2021 14:16:04 +0200
Subject: [PATCH] Added dimension reduction in Concat module

---
 reading_machine/src/Classifier.cpp                  | 3 ++-
 torch_modules/include/Concat.hpp                    | 4 +++-
 torch_modules/src/Concat.cpp                        | 7 ++++---
 torch_modules/src/ContextModule.cpp                 | 2 +-
 torch_modules/src/ContextualModule.cpp              | 2 +-
 torch_modules/src/DepthLayerTreeEmbeddingModule.cpp | 2 +-
 torch_modules/src/DistanceModule.cpp                | 2 +-
 torch_modules/src/FocusedColumnModule.cpp           | 2 +-
 torch_modules/src/HistoryMineModule.cpp             | 2 +-
 torch_modules/src/HistoryModule.cpp                 | 2 +-
 torch_modules/src/NumericColumnModule.cpp           | 2 +-
 torch_modules/src/RawInputModule.cpp                | 2 +-
 torch_modules/src/SplitTransModule.cpp              | 2 +-
 torch_modules/src/UppercaseRateModule.cpp           | 2 +-
 trainer/src/MacaonTrain.cpp                         | 2 +-
 15 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 4e3ded1..e2dd7c9 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -101,7 +101,8 @@ int Classifier::getNbParameters() const
   int nbParameters = 0;
 
   for (auto & t : nn->parameters())
-    nbParameters += torch::numel(t);
+    if (t.requires_grad())
+      nbParameters += torch::numel(t);
 
   return nbParameters;
 }
diff --git a/torch_modules/include/Concat.hpp b/torch_modules/include/Concat.hpp
index 4c7de25..b6134b7 100644
--- a/torch_modules/include/Concat.hpp
+++ b/torch_modules/include/Concat.hpp
@@ -9,10 +9,12 @@ class ConcatImpl : public MyModule
   private :
 
   int inputSize;
+  int outputSize;
+  torch::nn::Linear dimReduce{nullptr};
 
   public :
 
-  ConcatImpl(int inputSize);
+  ConcatImpl(int inputSize, int outputSize);
   torch::Tensor forward(torch::Tensor input);
   int getOutputSize(int sequenceLength);
 };
diff --git a/torch_modules/src/Concat.cpp b/torch_modules/src/Concat.cpp
index 09d99c6..7ba200c 100644
--- a/torch_modules/src/Concat.cpp
+++ b/torch_modules/src/Concat.cpp
@@ -1,16 +1,17 @@
 #include "Concat.hpp"
 
-ConcatImpl::ConcatImpl(int inputSize) : inputSize(inputSize)
+ConcatImpl::ConcatImpl(int inputSize, int outputSize) : inputSize(inputSize), outputSize(outputSize)
 {
+  dimReduce = register_module("dimReduce", torch::nn::Linear(inputSize, outputSize));
 }
 
 torch::Tensor ConcatImpl::forward(torch::Tensor input)
 {
-  return input.view({input.size(0), -1});
+  return dimReduce(input).view({input.size(0), -1});
 }
 
 int ConcatImpl::getOutputSize(int sequenceLength)
 {
-  return sequenceLength * inputSize;
+  return sequenceLength * outputSize;
 }
 
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index b99f6ea..b1c7fd2 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -42,7 +42,7 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
             else if (subModuleType == "GRU")
               myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
             else if (subModuleType == "Concat")
-              myModule = register_module("myModule", Concat(inSize));
+              myModule = register_module("myModule", Concat(inSize, outSize));
             else if (subModuleType == "Transformer")
               myModule = register_module("myModule", Transformer(columns.size()*inSize, outSize, options));
             else
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index 6992524..bd825f7 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -49,7 +49,7 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string &
             else if (subModuleType == "GRU")
               myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
             else if (subModuleType == "Concat")
-              myModule = register_module("myModule", Concat(inSize));
+              myModule = register_module("myModule", Concat(inSize, outSize));
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index a60433e..acc45d5 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -39,7 +39,7 @@ DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(std::string
               else if (subModuleType == "GRU")
                 depthModules.emplace_back(register_module(name, GRU(columns.size()*inSize, outSize, options)));
               else if (subModuleType == "Concat")
-                depthModules.emplace_back(register_module(name, Concat(inSize)));
+                depthModules.emplace_back(register_module(name, Concat(inSize, outSize)));
               else
                 util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
             }
diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp
index a51eea0..0aebe58 100644
--- a/torch_modules/src/DistanceModule.cpp
+++ b/torch_modules/src/DistanceModule.cpp
@@ -39,7 +39,7 @@ DistanceModuleImpl::DistanceModuleImpl(std::string name, const std::string & def
             else if (subModuleType == "GRU")
               myModule = register_module("myModule", GRU(inSize, outSize, options));
             else if (subModuleType == "Concat")
-              myModule = register_module("myModule", Concat(inSize));
+              myModule = register_module("myModule", Concat(inSize, outSize));
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 107e956..62da3de 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -35,7 +35,7 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
             else if (subModuleType == "GRU")
               myModule = register_module("myModule", GRU(inSize, outSize, options));
             else if (subModuleType == "Concat")
-              myModule = register_module("myModule", Concat(inSize));
+              myModule = register_module("myModule", Concat(inSize, outSize));
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/HistoryMineModule.cpp b/torch_modules/src/HistoryMineModule.cpp
index 25bfcc1..7d1c6f5 100644
--- a/torch_modules/src/HistoryMineModule.cpp
+++ b/torch_modules/src/HistoryMineModule.cpp
@@ -29,7 +29,7 @@ HistoryMineModuleImpl::HistoryMineModuleImpl(std::string name, const std::string
             else if (subModuleType == "CNN")
               myModule = register_module("myModule", CNN(inSize, outSize, options));
             else if (subModuleType == "Concat")
-              myModule = register_module("myModule", Concat(inSize));
+              myModule = register_module("myModule", Concat(inSize, outSize));
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
index dddfdf7..c897364 100644
--- a/torch_modules/src/HistoryModule.cpp
+++ b/torch_modules/src/HistoryModule.cpp
@@ -29,7 +29,7 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin
             else if (subModuleType == "CNN")
               myModule = register_module("myModule", CNN(inSize, outSize, options));
             else if (subModuleType == "Concat")
-              myModule = register_module("myModule", Concat(inSize));
+              myModule = register_module("myModule", Concat(inSize, outSize));
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp
index a5001e7..a666fc3 100644
--- a/torch_modules/src/NumericColumnModule.cpp
+++ b/torch_modules/src/NumericColumnModule.cpp
@@ -35,7 +35,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
             else if (subModuleType == "GRU")
               myModule = register_module("myModule", GRU(1, outSize, options));
             else if (subModuleType == "Concat")
-              myModule = register_module("myModule", Concat(1));
+              myModule = register_module("myModule", Concat(1,1));
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
           } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index d237485..c5dc9a5 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -31,7 +31,7 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def
             else if (subModuleType == "GRU")
               myModule = register_module("myModule", GRU(inSize, outSize, options));
             else if (subModuleType == "Concat")
-              myModule = register_module("myModule", Concat(inSize));
+              myModule = register_module("myModule", Concat(inSize, outSize));
             else if (subModuleType == "Transformer")
               myModule = register_module("myModule", Transformer(inSize, outSize, options));
             else
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index 5f361a0..ee3fa38 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -27,7 +27,7 @@ SplitTransModuleImpl::SplitTransModuleImpl(std::string name, int maxNbTrans, con
             else if (subModuleType == "GRU")
               myModule = register_module("myModule", GRU(inSize, outSize, options));
             else if (subModuleType == "Concat")
-              myModule = register_module("myModule", Concat(inSize));
+              myModule = register_module("myModule", Concat(inSize, outSize));
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp
index ff6e0ac..8d86c74 100644
--- a/torch_modules/src/UppercaseRateModule.cpp
+++ b/torch_modules/src/UppercaseRateModule.cpp
@@ -31,7 +31,7 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st
             else if (subModuleType == "GRU")
               myModule = register_module("myModule", GRU(1, outSize, options));
             else if (subModuleType == "Concat")
-              myModule = register_module("myModule", Concat(1));
+              myModule = register_module("myModule", Concat(1, 1));
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
           } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 9846783..cef2bea 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -282,7 +282,7 @@ int MacaonTrain::main()
       {
         machine.resetClassifiers();
         machine.trainMode(currentEpoch == 0);
-        fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getNbParameters()));
+        fmt::print(stderr, "[{}] Model has {} trainable parameters\n", util::getTime(), util::int2HumanStr(machine.getNbParameters()));
       }
 
       machine.resetOptimizers();
-- 
GitLab