From ed05ee4ac47f827f31bd8de5b760c772106be90f Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 5 Jun 2020 09:58:44 +0200
Subject: [PATCH] Added Concat module

---
 torch_modules/include/Concat.hpp              | 22 +++++++++++++++++++
 torch_modules/include/ContextModule.hpp       |  1 +
 .../include/DepthLayerTreeEmbeddingModule.hpp |  1 +
 torch_modules/include/FocusedColumnModule.hpp |  1 +
 torch_modules/include/HistoryModule.hpp       |  1 +
 torch_modules/include/NumericColumnModule.hpp |  1 +
 torch_modules/include/RawInputModule.hpp      |  1 +
 torch_modules/include/SplitTransModule.hpp    |  1 +
 torch_modules/include/UppercaseRateModule.hpp |  1 +
 torch_modules/src/Concat.cpp                  | 16 ++++++++++++++
 torch_modules/src/ContextModule.cpp           |  2 ++
 .../src/DepthLayerTreeEmbeddingModule.cpp     |  2 ++
 torch_modules/src/FocusedColumnModule.cpp     |  2 ++
 torch_modules/src/HistoryModule.cpp           |  2 ++
 torch_modules/src/NumericColumnModule.cpp     |  2 ++
 torch_modules/src/RawInputModule.cpp          |  2 ++
 torch_modules/src/SplitTransModule.cpp        |  2 ++
 torch_modules/src/UppercaseRateModule.cpp     |  2 ++
 18 files changed, 62 insertions(+)
 create mode 100644 torch_modules/include/Concat.hpp
 create mode 100644 torch_modules/src/Concat.cpp

diff --git a/torch_modules/include/Concat.hpp b/torch_modules/include/Concat.hpp
new file mode 100644
index 0000000..4c7de25
--- /dev/null
+++ b/torch_modules/include/Concat.hpp
@@ -0,0 +1,22 @@
+#ifndef Concat__H
+#define Concat__H
+
+#include <torch/torch.h>
+#include "MyModule.hpp"
+
+class ConcatImpl : public MyModule
+{
+  private :
+
+  int inputSize;
+
+  public :
+
+  ConcatImpl(int inputSize);
+  torch::Tensor forward(torch::Tensor input);
+  int getOutputSize(int sequenceLength);
+};
+TORCH_MODULE(Concat);
+
+#endif
+
diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp
index 3ab3895..b2f33cf 100644
--- a/torch_modules/include/ContextModule.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -6,6 +6,7 @@
 #include "MyModule.hpp"
 #include "GRU.hpp"
 #include "LSTM.hpp"
+#include "Concat.hpp"
 
 class ContextModuleImpl : public Submodule
 {
diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
index c3d8ce3..277f7fb 100644
--- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
@@ -6,6 +6,7 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "Concat.hpp"
 
 class DepthLayerTreeEmbeddingModuleImpl : public Submodule
 {
diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp
index 05da795..cfd9c32 100644
--- a/torch_modules/include/FocusedColumnModule.hpp
+++ b/torch_modules/include/FocusedColumnModule.hpp
@@ -6,6 +6,7 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "Concat.hpp"
 
 class FocusedColumnModuleImpl : public Submodule
 {
diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp
index 3d9b2ff..594df1f 100644
--- a/torch_modules/include/HistoryModule.hpp
+++ b/torch_modules/include/HistoryModule.hpp
@@ -6,6 +6,7 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "Concat.hpp"
 
 class HistoryModuleImpl : public Submodule
 {
diff --git a/torch_modules/include/NumericColumnModule.hpp b/torch_modules/include/NumericColumnModule.hpp
index 16348b9..82e3d37 100644
--- a/torch_modules/include/NumericColumnModule.hpp
+++ b/torch_modules/include/NumericColumnModule.hpp
@@ -6,6 +6,7 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "Concat.hpp"
 
 class NumericColumnModuleImpl : public Submodule
 {
diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp
index c78ac8c..d3a0e6c 100644
--- a/torch_modules/include/RawInputModule.hpp
+++ b/torch_modules/include/RawInputModule.hpp
@@ -6,6 +6,7 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "Concat.hpp"
 
 class RawInputModuleImpl : public Submodule
 {
diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp
index 643ee71..f738cdd 100644
--- a/torch_modules/include/SplitTransModule.hpp
+++ b/torch_modules/include/SplitTransModule.hpp
@@ -6,6 +6,7 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "Concat.hpp"
 
 class SplitTransModuleImpl : public Submodule
 {
diff --git a/torch_modules/include/UppercaseRateModule.hpp b/torch_modules/include/UppercaseRateModule.hpp
index 4256e06..e28366e 100644
--- a/torch_modules/include/UppercaseRateModule.hpp
+++ b/torch_modules/include/UppercaseRateModule.hpp
@@ -6,6 +6,7 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "Concat.hpp"
 
 class UppercaseRateModuleImpl : public Submodule
 {
diff --git a/torch_modules/src/Concat.cpp b/torch_modules/src/Concat.cpp
new file mode 100644
index 0000000..09d99c6
--- /dev/null
+++ b/torch_modules/src/Concat.cpp
@@ -0,0 +1,16 @@
+#include "Concat.hpp"
+
+ConcatImpl::ConcatImpl(int inputSize) : inputSize(inputSize)
+{
+}
+
+torch::Tensor ConcatImpl::forward(torch::Tensor input)
+{
+  return input.view({input.size(0), -1});
+}
+
+int ConcatImpl::getOutputSize(int sequenceLength)
+{
+  return sequenceLength * inputSize;
+}
+
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index f9c1c84..21723b1 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -32,6 +32,8 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
     myModule = register_module("myModule", LSTM(columns.size()*inSize, outSize, options));
   else if (subModuleType == "GRU")
     myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
+  else if (subModuleType == "Concat")
+    myModule = register_module("myModule", Concat(inSize));
   else
     util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index 0d8111e..4894eb9 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -38,6 +38,8 @@ DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(std::string
       depthModules.emplace_back(register_module(name, LSTM(columns.size()*inSize, outSize, options)));
     else if (subModuleType == "GRU")
       depthModules.emplace_back(register_module(name, GRU(columns.size()*inSize, outSize, options)));
+    else if (subModuleType == "Concat")
+      depthModules.emplace_back(register_module(name, Concat(inSize)));
     else
       util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
   }
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 9f7f766..5ac927a 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -33,6 +33,8 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
     myModule = register_module("myModule", LSTM(inSize, outSize, options));
   else if (subModuleType == "GRU")
     myModule = register_module("myModule", GRU(inSize, outSize, options));
+  else if (subModuleType == "Concat")
+    myModule = register_module("myModule", Concat(inSize));
   else
     util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
index 1f0fa52..be36990 100644
--- a/torch_modules/src/HistoryModule.cpp
+++ b/torch_modules/src/HistoryModule.cpp
@@ -26,6 +26,8 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin
     myModule = register_module("myModule", LSTM(inSize, outSize, options));
   else if (subModuleType == "GRU")
     myModule = register_module("myModule", GRU(inSize, outSize, options));
+  else if (subModuleType == "Concat")
+    myModule = register_module("myModule", Concat(inSize));
   else
     util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp
index 5f8c8d4..c94ac66 100644
--- a/torch_modules/src/NumericColumnModule.cpp
+++ b/torch_modules/src/NumericColumnModule.cpp
@@ -32,6 +32,8 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
       myModule = register_module("myModule", LSTM(1, outSize, options));
     else if (subModuleType == "GRU")
       myModule = register_module("myModule", GRU(1, outSize, options));
+    else if (subModuleType == "Concat")
+      myModule = register_module("myModule", Concat(1));
     else
       util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
   } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index ae6fd80..14cd3bc 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -27,6 +27,8 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def
     myModule = register_module("myModule", LSTM(inSize, outSize, options));
   else if (subModuleType == "GRU")
     myModule = register_module("myModule", GRU(inSize, outSize, options));
+  else if (subModuleType == "Concat")
+    myModule = register_module("myModule", Concat(inSize));
   else
     util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index 7994f2d..45c268a 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -26,6 +26,8 @@ SplitTransModuleImpl::SplitTransModuleImpl(std::string name, int maxNbTrans, con
     myModule = register_module("myModule", LSTM(inSize, outSize, options));
   else if (subModuleType == "GRU")
     myModule = register_module("myModule", GRU(inSize, outSize, options));
+  else if (subModuleType == "Concat")
+    myModule = register_module("myModule", Concat(inSize));
   else
     util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp
index 7f92e05..478651c 100644
--- a/torch_modules/src/UppercaseRateModule.cpp
+++ b/torch_modules/src/UppercaseRateModule.cpp
@@ -30,6 +30,8 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st
       myModule = register_module("myModule", LSTM(1, outSize, options));
     else if (subModuleType == "GRU")
       myModule = register_module("myModule", GRU(1, outSize, options));
+    else if (subModuleType == "Concat")
+      myModule = register_module("myModule", Concat(1));
     else
       util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
   } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
--
GitLab
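Note (not part of the patch): the change lets a submodule be configured with "Concat" instead of "LSTM" or "GRU", in which case the per-element embeddings are simply flattened into one vector rather than summarized by a recurrent network. The following is a minimal standalone sketch of what the new ConcatImpl computes; it assumes only a working libtorch installation and reproduces the forward/getOutputSize logic directly instead of linking against MyModule.hpp, so the file layout and build setup here are illustrative, not taken from the repository.

// Standalone sketch (illustrative): reproduces the reshape done by
// ConcatImpl::forward and checks it against ConcatImpl::getOutputSize.
#include <torch/torch.h>
#include <iostream>

int main()
{
  int inputSize = 3;       // embedding size per element
  int sequenceLength = 5;  // number of elements fed to the submodule
  int batchSize = 2;

  // Input shaped (batch, sequenceLength, inputSize).
  auto input = torch::randn({batchSize, sequenceLength, inputSize});

  // Same operation as ConcatImpl::forward: keep the batch dimension and
  // flatten the rest, i.e. concatenate the per-element embeddings.
  auto output = input.view({input.size(0), -1});

  // Matches getOutputSize(sequenceLength) == sequenceLength * inputSize.
  std::cout << output.size(1) << " == " << sequenceLength * inputSize << std::endl;

  return 0;
}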