From ed05ee4ac47f827f31bd8de5b760c772106be90f Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 5 Jun 2020 09:58:44 +0200
Subject: [PATCH] Added Concat module

---
 torch_modules/include/Concat.hpp              | 22 +++++++++++++++++++
 torch_modules/include/ContextModule.hpp       |  1 +
 .../include/DepthLayerTreeEmbeddingModule.hpp |  1 +
 torch_modules/include/FocusedColumnModule.hpp |  1 +
 torch_modules/include/HistoryModule.hpp       |  1 +
 torch_modules/include/NumericColumnModule.hpp |  1 +
 torch_modules/include/RawInputModule.hpp      |  1 +
 torch_modules/include/SplitTransModule.hpp    |  1 +
 torch_modules/include/UppercaseRateModule.hpp |  1 +
 torch_modules/src/Concat.cpp                  | 16 ++++++++++++++
 torch_modules/src/ContextModule.cpp           |  2 ++
 .../src/DepthLayerTreeEmbeddingModule.cpp     |  2 ++
 torch_modules/src/FocusedColumnModule.cpp     |  2 ++
 torch_modules/src/HistoryModule.cpp           |  2 ++
 torch_modules/src/NumericColumnModule.cpp     |  2 ++
 torch_modules/src/RawInputModule.cpp          |  2 ++
 torch_modules/src/SplitTransModule.cpp        |  2 ++
 torch_modules/src/UppercaseRateModule.cpp     |  2 ++
 18 files changed, 62 insertions(+)
 create mode 100644 torch_modules/include/Concat.hpp
 create mode 100644 torch_modules/src/Concat.cpp

diff --git a/torch_modules/include/Concat.hpp b/torch_modules/include/Concat.hpp
new file mode 100644
index 0000000..4c7de25
--- /dev/null
+++ b/torch_modules/include/Concat.hpp
@@ -0,0 +1,22 @@
+#ifndef Concat__H
+#define Concat__H
+
+#include <torch/torch.h>
+#include "MyModule.hpp"
+
+class ConcatImpl : public MyModule
+{
+  private :
+
+  int inputSize;
+
+  public :
+
+  ConcatImpl(int inputSize);
+  torch::Tensor forward(torch::Tensor input);
+  int getOutputSize(int sequenceLength);
+};
+TORCH_MODULE(Concat);
+
+#endif
+
diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp
index 3ab3895..b2f33cf 100644
--- a/torch_modules/include/ContextModule.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -6,6 +6,7 @@
 #include "MyModule.hpp"
 #include "GRU.hpp"
 #include "LSTM.hpp"
+#include "Concat.hpp"
 
 class ContextModuleImpl : public Submodule
 {
diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
index c3d8ce3..277f7fb 100644
--- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
@@ -6,6 +6,7 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "Concat.hpp"
 
 class DepthLayerTreeEmbeddingModuleImpl : public Submodule
 {
diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp
index 05da795..cfd9c32 100644
--- a/torch_modules/include/FocusedColumnModule.hpp
+++ b/torch_modules/include/FocusedColumnModule.hpp
@@ -6,6 +6,7 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "Concat.hpp"
 
 class FocusedColumnModuleImpl : public Submodule
 {
diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp
index 3d9b2ff..594df1f 100644
--- a/torch_modules/include/HistoryModule.hpp
+++ b/torch_modules/include/HistoryModule.hpp
@@ -6,6 +6,7 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "Concat.hpp"
 
 class HistoryModuleImpl : public Submodule
 {
diff --git a/torch_modules/include/NumericColumnModule.hpp b/torch_modules/include/NumericColumnModule.hpp
index 16348b9..82e3d37 100644
--- a/torch_modules/include/NumericColumnModule.hpp
+++ b/torch_modules/include/NumericColumnModule.hpp
@@ -6,6 +6,7 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "Concat.hpp"
 
 class NumericColumnModuleImpl : public Submodule
 {
diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp
index c78ac8c..d3a0e6c 100644
--- a/torch_modules/include/RawInputModule.hpp
+++ b/torch_modules/include/RawInputModule.hpp
@@ -6,6 +6,7 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "Concat.hpp"
 
 class RawInputModuleImpl : public Submodule
 {
diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp
index 643ee71..f738cdd 100644
--- a/torch_modules/include/SplitTransModule.hpp
+++ b/torch_modules/include/SplitTransModule.hpp
@@ -6,6 +6,7 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "Concat.hpp"
 
 class SplitTransModuleImpl : public Submodule
 {
diff --git a/torch_modules/include/UppercaseRateModule.hpp b/torch_modules/include/UppercaseRateModule.hpp
index 4256e06..e28366e 100644
--- a/torch_modules/include/UppercaseRateModule.hpp
+++ b/torch_modules/include/UppercaseRateModule.hpp
@@ -6,6 +6,7 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "Concat.hpp"
 
 class UppercaseRateModuleImpl : public Submodule
 {
diff --git a/torch_modules/src/Concat.cpp b/torch_modules/src/Concat.cpp
new file mode 100644
index 0000000..09d99c6
--- /dev/null
+++ b/torch_modules/src/Concat.cpp
@@ -0,0 +1,16 @@
+#include "Concat.hpp"
+
+ConcatImpl::ConcatImpl(int inputSize) : inputSize(inputSize)
+{
+}
+
+torch::Tensor ConcatImpl::forward(torch::Tensor input)
+{
+  return input.view({input.size(0), -1});
+}
+
+int ConcatImpl::getOutputSize(int sequenceLength)
+{
+  return sequenceLength * inputSize;
+}
+
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index f9c1c84..21723b1 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -32,6 +32,8 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
               myModule = register_module("myModule", LSTM(columns.size()*inSize, outSize, options));
             else if (subModuleType == "GRU")
               myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
+            else if (subModuleType == "Concat")
+              myModule = register_module("myModule", Concat(inSize));
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index 0d8111e..4894eb9 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -38,6 +38,8 @@ DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(std::string
                 depthModules.emplace_back(register_module(name, LSTM(columns.size()*inSize, outSize, options)));
               else if (subModuleType == "GRU")
                 depthModules.emplace_back(register_module(name, GRU(columns.size()*inSize, outSize, options)));
+              else if (subModuleType == "Concat")
+                depthModules.emplace_back(register_module(name, Concat(inSize)));
               else
                 util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
             }
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 9f7f766..5ac927a 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -33,6 +33,8 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
               myModule = register_module("myModule", LSTM(inSize, outSize, options));
             else if (subModuleType == "GRU")
               myModule = register_module("myModule", GRU(inSize, outSize, options));
+            else if (subModuleType == "Concat")
+              myModule = register_module("myModule", Concat(inSize));
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
index 1f0fa52..be36990 100644
--- a/torch_modules/src/HistoryModule.cpp
+++ b/torch_modules/src/HistoryModule.cpp
@@ -26,6 +26,8 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin
               myModule = register_module("myModule", LSTM(inSize, outSize, options));
             else if (subModuleType == "GRU")
               myModule = register_module("myModule", GRU(inSize, outSize, options));
+            else if (subModuleType == "Concat")
+              myModule = register_module("myModule", Concat(inSize));
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp
index 5f8c8d4..c94ac66 100644
--- a/torch_modules/src/NumericColumnModule.cpp
+++ b/torch_modules/src/NumericColumnModule.cpp
@@ -32,6 +32,8 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
               myModule = register_module("myModule", LSTM(1, outSize, options));
             else if (subModuleType == "GRU")
               myModule = register_module("myModule", GRU(1, outSize, options));
+            else if (subModuleType == "Concat")
+              myModule = register_module("myModule", Concat(1));
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
           } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index ae6fd80..14cd3bc 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -27,6 +27,8 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def
               myModule = register_module("myModule", LSTM(inSize, outSize, options));
             else if (subModuleType == "GRU")
               myModule = register_module("myModule", GRU(inSize, outSize, options));
+            else if (subModuleType == "Concat")
+              myModule = register_module("myModule", Concat(inSize));
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index 7994f2d..45c268a 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -26,6 +26,8 @@ SplitTransModuleImpl::SplitTransModuleImpl(std::string name, int maxNbTrans, con
               myModule = register_module("myModule", LSTM(inSize, outSize, options));
             else if (subModuleType == "GRU")
               myModule = register_module("myModule", GRU(inSize, outSize, options));
+            else if (subModuleType == "Concat")
+              myModule = register_module("myModule", Concat(inSize));
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp
index 7f92e05..478651c 100644
--- a/torch_modules/src/UppercaseRateModule.cpp
+++ b/torch_modules/src/UppercaseRateModule.cpp
@@ -30,6 +30,8 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st
               myModule = register_module("myModule", LSTM(1, outSize, options));
             else if (subModuleType == "GRU")
               myModule = register_module("myModule", GRU(1, outSize, options));
+            else if (subModuleType == "Concat")
+              myModule = register_module("myModule", Concat(1));
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
           } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
-- 
GitLab