Skip to content
Snippets Groups Projects
Commit ed05ee4a authored by Franck Dary's avatar Franck Dary
Browse files

Added Concat module

parent f799c58f
No related branches found
No related tags found
No related merge requests found
Showing
with 62 additions and 0 deletions
#ifndef Concat__H
#define Concat__H
#include <torch/torch.h>
#include "MyModule.hpp"
class ConcatImpl : public MyModule
{
private :
int inputSize;
public :
ConcatImpl(int inputSize);
torch::Tensor forward(torch::Tensor input);
int getOutputSize(int sequenceLength);
};
TORCH_MODULE(Concat);
#endif
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "MyModule.hpp" #include "MyModule.hpp"
#include "GRU.hpp" #include "GRU.hpp"
#include "LSTM.hpp" #include "LSTM.hpp"
#include "Concat.hpp"
class ContextModuleImpl : public Submodule class ContextModuleImpl : public Submodule
{ {
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "MyModule.hpp" #include "MyModule.hpp"
#include "LSTM.hpp" #include "LSTM.hpp"
#include "GRU.hpp" #include "GRU.hpp"
#include "Concat.hpp"
class DepthLayerTreeEmbeddingModuleImpl : public Submodule class DepthLayerTreeEmbeddingModuleImpl : public Submodule
{ {
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "MyModule.hpp" #include "MyModule.hpp"
#include "LSTM.hpp" #include "LSTM.hpp"
#include "GRU.hpp" #include "GRU.hpp"
#include "Concat.hpp"
class FocusedColumnModuleImpl : public Submodule class FocusedColumnModuleImpl : public Submodule
{ {
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "MyModule.hpp" #include "MyModule.hpp"
#include "LSTM.hpp" #include "LSTM.hpp"
#include "GRU.hpp" #include "GRU.hpp"
#include "Concat.hpp"
class HistoryModuleImpl : public Submodule class HistoryModuleImpl : public Submodule
{ {
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "MyModule.hpp" #include "MyModule.hpp"
#include "LSTM.hpp" #include "LSTM.hpp"
#include "GRU.hpp" #include "GRU.hpp"
#include "Concat.hpp"
class NumericColumnModuleImpl : public Submodule class NumericColumnModuleImpl : public Submodule
{ {
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "MyModule.hpp" #include "MyModule.hpp"
#include "LSTM.hpp" #include "LSTM.hpp"
#include "GRU.hpp" #include "GRU.hpp"
#include "Concat.hpp"
class RawInputModuleImpl : public Submodule class RawInputModuleImpl : public Submodule
{ {
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "MyModule.hpp" #include "MyModule.hpp"
#include "LSTM.hpp" #include "LSTM.hpp"
#include "GRU.hpp" #include "GRU.hpp"
#include "Concat.hpp"
class SplitTransModuleImpl : public Submodule class SplitTransModuleImpl : public Submodule
{ {
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "MyModule.hpp" #include "MyModule.hpp"
#include "LSTM.hpp" #include "LSTM.hpp"
#include "GRU.hpp" #include "GRU.hpp"
#include "Concat.hpp"
class UppercaseRateModuleImpl : public Submodule class UppercaseRateModuleImpl : public Submodule
{ {
......
#include "Concat.hpp"
ConcatImpl::ConcatImpl(int inputSize) : inputSize(inputSize)
{
}
torch::Tensor ConcatImpl::forward(torch::Tensor input)
{
return input.view({input.size(0), -1});
}
int ConcatImpl::getOutputSize(int sequenceLength)
{
return sequenceLength * inputSize;
}
...@@ -32,6 +32,8 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin ...@@ -32,6 +32,8 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
myModule = register_module("myModule", LSTM(columns.size()*inSize, outSize, options)); myModule = register_module("myModule", LSTM(columns.size()*inSize, outSize, options));
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options)); myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
...@@ -38,6 +38,8 @@ DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(std::string ...@@ -38,6 +38,8 @@ DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(std::string
depthModules.emplace_back(register_module(name, LSTM(columns.size()*inSize, outSize, options))); depthModules.emplace_back(register_module(name, LSTM(columns.size()*inSize, outSize, options)));
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
depthModules.emplace_back(register_module(name, GRU(columns.size()*inSize, outSize, options))); depthModules.emplace_back(register_module(name, GRU(columns.size()*inSize, outSize, options)));
else if (subModuleType == "Concat")
depthModules.emplace_back(register_module(name, Concat(inSize)));
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
} }
......
...@@ -33,6 +33,8 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st ...@@ -33,6 +33,8 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
myModule = register_module("myModule", LSTM(inSize, outSize, options)); myModule = register_module("myModule", LSTM(inSize, outSize, options));
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options)); myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
...@@ -26,6 +26,8 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin ...@@ -26,6 +26,8 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin
myModule = register_module("myModule", LSTM(inSize, outSize, options)); myModule = register_module("myModule", LSTM(inSize, outSize, options));
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options)); myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
...@@ -32,6 +32,8 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st ...@@ -32,6 +32,8 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
myModule = register_module("myModule", LSTM(1, outSize, options)); myModule = register_module("myModule", LSTM(1, outSize, options));
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(1, outSize, options)); myModule = register_module("myModule", GRU(1, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(1));
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
......
...@@ -27,6 +27,8 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def ...@@ -27,6 +27,8 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def
myModule = register_module("myModule", LSTM(inSize, outSize, options)); myModule = register_module("myModule", LSTM(inSize, outSize, options));
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options)); myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
...@@ -26,6 +26,8 @@ SplitTransModuleImpl::SplitTransModuleImpl(std::string name, int maxNbTrans, con ...@@ -26,6 +26,8 @@ SplitTransModuleImpl::SplitTransModuleImpl(std::string name, int maxNbTrans, con
myModule = register_module("myModule", LSTM(inSize, outSize, options)); myModule = register_module("myModule", LSTM(inSize, outSize, options));
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options)); myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
...@@ -30,6 +30,8 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st ...@@ -30,6 +30,8 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st
myModule = register_module("myModule", LSTM(1, outSize, options)); myModule = register_module("myModule", LSTM(1, outSize, options));
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(1, outSize, options)); myModule = register_module("myModule", GRU(1, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(1));
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment