Commit ed05ee4a authored by Franck Dary's avatar Franck Dary
Browse files

Added Concat module

parent f799c58f
#ifndef Concat__H
#define Concat__H
#include <torch/torch.h>
#include "MyModule.hpp"
class ConcatImpl : public MyModule
{
private :
int inputSize;
public :
ConcatImpl(int inputSize);
torch::Tensor forward(torch::Tensor input);
int getOutputSize(int sequenceLength);
};
TORCH_MODULE(Concat);
#endif
......@@ -6,6 +6,7 @@
#include "MyModule.hpp"
#include "GRU.hpp"
#include "LSTM.hpp"
#include "Concat.hpp"
class ContextModuleImpl : public Submodule
{
......
......@@ -6,6 +6,7 @@
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
#include "Concat.hpp"
class DepthLayerTreeEmbeddingModuleImpl : public Submodule
{
......
......@@ -6,6 +6,7 @@
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
#include "Concat.hpp"
class FocusedColumnModuleImpl : public Submodule
{
......
......@@ -6,6 +6,7 @@
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
#include "Concat.hpp"
class HistoryModuleImpl : public Submodule
{
......
......@@ -6,6 +6,7 @@
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
#include "Concat.hpp"
class NumericColumnModuleImpl : public Submodule
{
......
......@@ -6,6 +6,7 @@
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
#include "Concat.hpp"
class RawInputModuleImpl : public Submodule
{
......
......@@ -6,6 +6,7 @@
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
#include "Concat.hpp"
class SplitTransModuleImpl : public Submodule
{
......
......@@ -6,6 +6,7 @@
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
#include "Concat.hpp"
class UppercaseRateModuleImpl : public Submodule
{
......
#include "Concat.hpp"
ConcatImpl::ConcatImpl(int inputSize) : inputSize(inputSize)
{
}
torch::Tensor ConcatImpl::forward(torch::Tensor input)
{
return input.view({input.size(0), -1});
}
int ConcatImpl::getOutputSize(int sequenceLength)
{
return sequenceLength * inputSize;
}
......@@ -32,6 +32,8 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
myModule = register_module("myModule", LSTM(columns.size()*inSize, outSize, options));
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
......@@ -38,6 +38,8 @@ DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(std::string
depthModules.emplace_back(register_module(name, LSTM(columns.size()*inSize, outSize, options)));
else if (subModuleType == "GRU")
depthModules.emplace_back(register_module(name, GRU(columns.size()*inSize, outSize, options)));
else if (subModuleType == "Concat")
depthModules.emplace_back(register_module(name, Concat(inSize)));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
}
......
......@@ -33,6 +33,8 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
myModule = register_module("myModule", LSTM(inSize, outSize, options));
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
......@@ -26,6 +26,8 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin
myModule = register_module("myModule", LSTM(inSize, outSize, options));
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
......@@ -32,6 +32,8 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
myModule = register_module("myModule", LSTM(1, outSize, options));
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(1, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(1));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
......
......@@ -27,6 +27,8 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def
myModule = register_module("myModule", LSTM(inSize, outSize, options));
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
......@@ -26,6 +26,8 @@ SplitTransModuleImpl::SplitTransModuleImpl(std::string name, int maxNbTrans, con
myModule = register_module("myModule", LSTM(inSize, outSize, options));
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
......@@ -30,6 +30,8 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st
myModule = register_module("myModule", LSTM(1, outSize, options));
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(1, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(1));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment