diff --git a/torch_modules/include/Concat.hpp b/torch_modules/include/Concat.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4c7de25154e933448394c15add2f6e052c1f85ff --- /dev/null +++ b/torch_modules/include/Concat.hpp @@ -0,0 +1,22 @@ +#ifndef Concat__H +#define Concat__H + +#include <torch/torch.h> +#include "MyModule.hpp" + +class ConcatImpl : public MyModule +{ + private : + + int inputSize; + + public : + + ConcatImpl(int inputSize); + torch::Tensor forward(torch::Tensor input); + int getOutputSize(int sequenceLength); +}; +TORCH_MODULE(Concat); + +#endif + diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp index 3ab3895ffd19a1899db0d98b588a7f13d365a6ae..b2f33cfed187c0f910cfafa2baea78102605aba5 100644 --- a/torch_modules/include/ContextModule.hpp +++ b/torch_modules/include/ContextModule.hpp @@ -6,6 +6,7 @@ #include "MyModule.hpp" #include "GRU.hpp" #include "LSTM.hpp" +#include "Concat.hpp" class ContextModuleImpl : public Submodule { diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp index c3d8ce31f818a3bca28ce3eceb2855dd77e15637..277f7fbe0d0da50bdc97bc4099a2f5678d594adb 100644 --- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp +++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp @@ -6,6 +6,7 @@ #include "MyModule.hpp" #include "LSTM.hpp" #include "GRU.hpp" +#include "Concat.hpp" class DepthLayerTreeEmbeddingModuleImpl : public Submodule { diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp index 05da7956dfcbbfde3466fc5d1dde19589e1e0d15..cfd9c32cf0823704f6f16767cf9ad29745b67594 100644 --- a/torch_modules/include/FocusedColumnModule.hpp +++ b/torch_modules/include/FocusedColumnModule.hpp @@ -6,6 +6,7 @@ #include "MyModule.hpp" #include "LSTM.hpp" #include "GRU.hpp" +#include "Concat.hpp" class FocusedColumnModuleImpl : public Submodule { diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp index 3d9b2ff5cf8961893fe2c9da859cfd12cf05c7b3..594df1fd6ebba6b4fd939427d7379c9e50fde568 100644 --- a/torch_modules/include/HistoryModule.hpp +++ b/torch_modules/include/HistoryModule.hpp @@ -6,6 +6,7 @@ #include "MyModule.hpp" #include "LSTM.hpp" #include "GRU.hpp" +#include "Concat.hpp" class HistoryModuleImpl : public Submodule { diff --git a/torch_modules/include/NumericColumnModule.hpp b/torch_modules/include/NumericColumnModule.hpp index 16348b9c028530e8cf429c1a98e2857b6e9bc32d..82e3d37b2aba01492269687afad9497459e93b27 100644 --- a/torch_modules/include/NumericColumnModule.hpp +++ b/torch_modules/include/NumericColumnModule.hpp @@ -6,6 +6,7 @@ #include "MyModule.hpp" #include "LSTM.hpp" #include "GRU.hpp" +#include "Concat.hpp" class NumericColumnModuleImpl : public Submodule { diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp index c78ac8ce7dc1cfd87c7f0f887c15e1508fb987d6..d3a0e6ca42a8a5020ffc008bca746a5808458b83 100644 --- a/torch_modules/include/RawInputModule.hpp +++ b/torch_modules/include/RawInputModule.hpp @@ -6,6 +6,7 @@ #include "MyModule.hpp" #include "LSTM.hpp" #include "GRU.hpp" +#include "Concat.hpp" class RawInputModuleImpl : public Submodule { diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp index 643ee7178da5b2b33c52aca03918a3b066977234..f738cddb99760da4bf8bb892a1259bb0d45153a2 100644 --- a/torch_modules/include/SplitTransModule.hpp +++ b/torch_modules/include/SplitTransModule.hpp @@ -6,6 +6,7 @@ #include "MyModule.hpp" #include "LSTM.hpp" #include "GRU.hpp" +#include "Concat.hpp" class SplitTransModuleImpl : public Submodule { diff --git a/torch_modules/include/UppercaseRateModule.hpp b/torch_modules/include/UppercaseRateModule.hpp index 4256e06ae6cd140b445541d6dbd3fea1ab550532..e28366e7786660e1b3136054333f73411fac6d9a 100644 --- a/torch_modules/include/UppercaseRateModule.hpp +++ b/torch_modules/include/UppercaseRateModule.hpp @@ -6,6 +6,7 @@ #include "MyModule.hpp" #include "LSTM.hpp" #include "GRU.hpp" +#include "Concat.hpp" class UppercaseRateModuleImpl : public Submodule { diff --git a/torch_modules/src/Concat.cpp b/torch_modules/src/Concat.cpp new file mode 100644 index 0000000000000000000000000000000000000000..09d99c64797133e7778a6dfaf055f3faab6a7b40 --- /dev/null +++ b/torch_modules/src/Concat.cpp @@ -0,0 +1,16 @@ +#include "Concat.hpp" + +ConcatImpl::ConcatImpl(int inputSize) : inputSize(inputSize) +{ +} + +torch::Tensor ConcatImpl::forward(torch::Tensor input) +{ + return input.view({input.size(0), -1}); +} + +int ConcatImpl::getOutputSize(int sequenceLength) +{ + return sequenceLength * inputSize; +} + diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index f9c1c8455fdc28b597c6c568bee6b097c0f3ee24..21723b14d6af1c3ae2e76bf3e899bf3b513e28c8 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -32,6 +32,8 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin myModule = register_module("myModule", LSTM(columns.size()*inSize, outSize, options)); else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options)); + else if (subModuleType == "Concat") + myModule = register_module("myModule", Concat(inSize)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp index 0d8111e0630794fa0bbf1b652ea6c254aad0a112..4894eb94d8d2023d35707baf601fca1d0ca69e5c 100644 --- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp +++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp @@ -38,6 +38,8 @@ DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(std::string depthModules.emplace_back(register_module(name, LSTM(columns.size()*inSize, outSize, options))); else if (subModuleType == "GRU") depthModules.emplace_back(register_module(name, GRU(columns.size()*inSize, outSize, options))); + else if (subModuleType == "Concat") + depthModules.emplace_back(register_module(name, Concat(inSize))); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); } diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 9f7f7660e217ee905e8ce86d8b77fc1b99cd704e..5ac927ab245522cc9fee405949515fbc0a1ab730 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -33,6 +33,8 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st myModule = register_module("myModule", LSTM(inSize, outSize, options)); else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(inSize, outSize, options)); + else if (subModuleType == "Concat") + myModule = register_module("myModule", Concat(inSize)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp index 1f0fa5277ceff54ccfabd9a6a7fd3c2daeddf116..be3699030f26ca8567d88a607d5a40dc8f0caf7c 100644 --- a/torch_modules/src/HistoryModule.cpp +++ b/torch_modules/src/HistoryModule.cpp @@ -26,6 +26,8 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin myModule = register_module("myModule", LSTM(inSize, outSize, options)); else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(inSize, outSize, options)); + else if (subModuleType == "Concat") + myModule = register_module("myModule", Concat(inSize)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp index 5f8c8d4740148635e5eae55d6d0feac990111923..c94ac6605f0424edd6d93915397738075e06f204 100644 --- a/torch_modules/src/NumericColumnModule.cpp +++ b/torch_modules/src/NumericColumnModule.cpp @@ -32,6 +32,8 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st myModule = register_module("myModule", LSTM(1, outSize, options)); else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(1, outSize, options)); + else if (subModuleType == "Concat") + myModule = register_module("myModule", Concat(1)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp index ae6fd80377aeaddbf888c5913f43e39c51c8826a..14cd3bc2eabdd6236bd00ba954ce85997d53727b 100644 --- a/torch_modules/src/RawInputModule.cpp +++ b/torch_modules/src/RawInputModule.cpp @@ -27,6 +27,8 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def myModule = register_module("myModule", LSTM(inSize, outSize, options)); else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(inSize, outSize, options)); + else if (subModuleType == "Concat") + myModule = register_module("myModule", Concat(inSize)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp index 7994f2da89c692d33b87f4a63a5a49dd5b989d63..45c268a3b59bb0821651f7e9b04e7d4decf13601 100644 --- a/torch_modules/src/SplitTransModule.cpp +++ b/torch_modules/src/SplitTransModule.cpp @@ -26,6 +26,8 @@ SplitTransModuleImpl::SplitTransModuleImpl(std::string name, int maxNbTrans, con myModule = register_module("myModule", LSTM(inSize, outSize, options)); else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(inSize, outSize, options)); + else if (subModuleType == "Concat") + myModule = register_module("myModule", Concat(inSize)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp index 7f92e05465803b014ddc43c93ad1c399081d5355..478651c6f399d47698e26440f562a9834afcfe6f 100644 --- a/torch_modules/src/UppercaseRateModule.cpp +++ b/torch_modules/src/UppercaseRateModule.cpp @@ -30,6 +30,8 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st myModule = register_module("myModule", LSTM(1, outSize, options)); else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(1, outSize, options)); + else if (subModuleType == "Concat") + myModule = register_module("myModule", Concat(1)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}