Commit ca72f53e authored by Franck Dary's avatar Franck Dary
Implemented all modules

parent b4228d7b
@@ -18,8 +18,6 @@ class Classifier
private :
void initNeuralNetwork(const std::vector<std::string> & definition);
-void initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
-void initGRU(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
public :
......
@@ -7,7 +7,7 @@
#include "LSTM.hpp"
#include "GRU.hpp"
-class DepthLayerTreeEmbeddingModule : public Submodule
+class DepthLayerTreeEmbeddingModuleImpl : public Submodule
{
private :
@@ -15,16 +15,18 @@ class DepthLayerTreeEmbeddingModule : public Submodule
std::vector<std::string> columns;
std::vector<int> focusedBuffer;
std::vector<int> focusedStack;
+torch::nn::Embedding wordEmbeddings{nullptr};
std::vector<std::shared_ptr<MyModule>> depthModules;
public :
-DepthLayerTreeEmbeddingModule(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, MyModule::ModuleOptions options);
+DepthLayerTreeEmbeddingModuleImpl(const std::string & definition);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
};
+TORCH_MODULE(DepthLayerTreeEmbeddingModule);
#endif
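The Impl suffix plus TORCH_MODULE is the usual libtorch holder idiom, which is why the classes are renamed here while call sites keep the old names. A rough sketch of what the macro generates (not its literal expansion):

// TORCH_MODULE(DepthLayerTreeEmbeddingModule) defines, approximately:
class DepthLayerTreeEmbeddingModule : public torch::nn::ModuleHolder<DepthLayerTreeEmbeddingModuleImpl>
{
public :
// inherit the holder constructors, so arguments forward to the Impl
using torch::nn::ModuleHolder<DepthLayerTreeEmbeddingModuleImpl>::ModuleHolder;
};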
@@ -7,10 +7,11 @@
#include "LSTM.hpp"
#include "GRU.hpp"
-class FocusedColumnModule : public Submodule
+class FocusedColumnModuleImpl : public Submodule
{
private :
+torch::nn::Embedding wordEmbeddings{nullptr};
std::shared_ptr<MyModule> myModule{nullptr};
std::vector<int> focusedBuffer, focusedStack;
std::string column;
@@ -18,12 +19,13 @@ class FocusedColumnModule : public Submodule
public :
-FocusedColumnModule(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options);
+FocusedColumnModuleImpl(const std::string & definition);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
};
+TORCH_MODULE(FocusedColumnModule);
#endif
#ifndef GRUNETWORK__H
#define GRUNETWORK__H
#include "NeuralNetwork.hpp"
#include "ContextModule.hpp"
#include "RawInputModule.hpp"
#include "SplitTransModule.hpp"
#include "FocusedColumnModule.hpp"
#include "DepthLayerTreeEmbeddingModule.hpp"
#include "MLP.hpp"
class GRUNetworkImpl : public NeuralNetworkImpl
{
// private :
//
// torch::nn::Embedding wordEmbeddings{nullptr};
// torch::nn::Dropout2d embeddingsDropout2d{nullptr};
// torch::nn::Dropout embeddingsDropout{nullptr};
// torch::nn::Dropout inputDropout{nullptr};
//
// MLP mlp{nullptr};
// ContextModule contextGRU{nullptr};
// RawInputModule rawInputGRU{nullptr};
// SplitTransModule splitTransGRU{nullptr};
// DepthLayerTreeEmbeddingModule treeEmbedding{nullptr};
// std::vector<FocusedColumnModule> focusedLstms;
// std::map<std::string,torch::nn::Linear> outputLayersPerState;
public :
GRUNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextGRUSize, int focusedGRUSize, int rawInputGRUSize, int splitTransGRUSize, int numLayers, bool bigru, float gruDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};
#endif
#ifndef LSTMNETWORK__H
#define LSTMNETWORK__H
#include "NeuralNetwork.hpp"
#include "ContextModule.hpp"
#include "RawInputModule.hpp"
#include "SplitTransModule.hpp"
#include "FocusedColumnModule.hpp"
#include "DepthLayerTreeEmbeddingModule.hpp"
#include "MLP.hpp"
class LSTMNetworkImpl : public NeuralNetworkImpl
{
// private :
//
// torch::nn::Embedding wordEmbeddings{nullptr};
// torch::nn::Dropout2d embeddingsDropout2d{nullptr};
// torch::nn::Dropout embeddingsDropout{nullptr};
// torch::nn::Dropout inputDropout{nullptr};
//
// MLP mlp{nullptr};
// ContextModule contextLSTM{nullptr};
// RawInputModule rawInputLSTM{nullptr};
// SplitTransModule splitTransLSTM{nullptr};
// DepthLayerTreeEmbeddingModule treeEmbedding{nullptr};
// std::vector<FocusedColumnModule> focusedLstms;
// std::map<std::string,torch::nn::Linear> outputLayersPerState;
public :
LSTMNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};
#endif
@@ -13,9 +13,9 @@ class ModularNetworkImpl : public NeuralNetworkImpl
{
private :
-torch::nn::Embedding wordEmbeddings{nullptr};
-torch::nn::Dropout2d embeddingsDropout2d{nullptr};
-torch::nn::Dropout embeddingsDropout{nullptr};
+//torch::nn::Embedding wordEmbeddings{nullptr};
+//torch::nn::Dropout2d embeddingsDropout2d{nullptr};
+//torch::nn::Dropout embeddingsDropout{nullptr};
torch::nn::Dropout inputDropout{nullptr};
MLP mlp{nullptr};
......
@@ -7,21 +7,23 @@
#include "LSTM.hpp"
#include "GRU.hpp"
-class RawInputModule : public Submodule
+class RawInputModuleImpl : public Submodule
{
private :
+torch::nn::Embedding wordEmbeddings{nullptr};
std::shared_ptr<MyModule> myModule{nullptr};
int leftWindow, rightWindow;
public :
-RawInputModule(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options);
+RawInputModuleImpl(const std::string & definition);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
};
+TORCH_MODULE(RawInputModule);
#endif
@@ -7,21 +7,23 @@
#include "LSTM.hpp"
#include "GRU.hpp"
-class SplitTransModule : public Submodule
+class SplitTransModuleImpl : public Submodule
{
private :
+torch::nn::Embedding wordEmbeddings{nullptr};
std::shared_ptr<MyModule> myModule{nullptr};
int maxNbTrans;
public :
-SplitTransModule(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options);
+SplitTransModuleImpl(int maxNbTrans, const std::string & definition);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
};
+TORCH_MODULE(SplitTransModule);
#endif
#include "DepthLayerTreeEmbeddingModule.hpp"
-DepthLayerTreeEmbeddingModule::DepthLayerTreeEmbeddingModule(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, MyModule::ModuleOptions options) :
-maxElemPerDepth(maxElemPerDepth), columns(columns), focusedBuffer(focusedBuffer), focusedStack(focusedStack)
+DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(const std::string & definition)
{
-for (unsigned int i = 0; i < maxElemPerDepth.size(); i++)
-depthModules.emplace_back(register_module(fmt::format("lstm_{}",i), LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options)));
+std::regex regex("(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)LayerSizes\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+{
+try
+{
+columns = util::split(sm.str(1), ' ');
+for (auto & index : util::split(sm.str(2), ' '))
+focusedBuffer.emplace_back(std::stoi(index));
+for (auto & index : util::split(sm.str(3), ' '))
+focusedStack.emplace_back(std::stoi(index));
+for (auto & elem : util::split(sm.str(4), ' '))
+maxElemPerDepth.emplace_back(std::stoi(elem));
+auto subModuleType = sm.str(5);
+auto subModuleArguments = util::split(sm.str(6), ' ');
+auto options = MyModule::ModuleOptions(true)
+.bidirectional(std::stoi(subModuleArguments[0]))
+.num_layers(std::stoi(subModuleArguments[1]))
+.dropout(std::stof(subModuleArguments[2]))
+.complete(std::stoi(subModuleArguments[3]));
+int inSize = std::stoi(sm.str(7));
+int outSize = std::stoi(sm.str(8));
+wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
+for (unsigned int i = 0; i < maxElemPerDepth.size(); i++)
+{
+std::string name = fmt::format("{}_{}", i, subModuleType);
+if (subModuleType == "LSTM")
+depthModules.emplace_back(register_module(name, LSTM(columns.size()*inSize, outSize, options)));
+else if (subModuleType == "GRU")
+depthModules.emplace_back(register_module(name, GRU(columns.size()*inSize, outSize, options)));
+else
+util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
+}
+} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+}))
+util::myThrow(fmt::format("invalid definition '{}'", definition));
+}
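A definition string accepted by this regex would look roughly like the following (hypothetical column names and sizes; the four submodule arguments are read as bidirectional, num_layers, dropout and complete, in that order, and LayerSizes fills maxElemPerDepth with one submodule per depth):

Columns{FORM UPOS} Buffer{0 1} Stack{0} LayerSizes{3 2} LSTM{1 1 0.3 0} In{64} Out{128}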
-torch::Tensor DepthLayerTreeEmbeddingModule::forward(torch::Tensor input)
+torch::Tensor DepthLayerTreeEmbeddingModuleImpl::forward(torch::Tensor input)
{
-auto context = input.narrow(1, firstInputIndex, getInputSize());
+auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()));
std::vector<torch::Tensor> outputs;
@@ -17,14 +58,14 @@ torch::Tensor DepthLayerTreeEmbeddingModule::forward(torch::Tensor input)
for (unsigned int focused = 0; focused < focusedBuffer.size()+focusedStack.size(); focused++)
for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
{
-outputs.emplace_back(depthModules[depth]->forward(context.narrow(1, offset, maxElemPerDepth[depth]*columns.size()).view({input.size(0), maxElemPerDepth[depth], (long)columns.size()*input.size(2)})));
+outputs.emplace_back(depthModules[depth]->forward(context.narrow(1, offset, maxElemPerDepth[depth]*columns.size()).view({context.size(0), maxElemPerDepth[depth], (long)columns.size()*context.size(2)})));
offset += maxElemPerDepth[depth]*columns.size();
}
return torch::cat(outputs, 1);
}
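To make the narrow/view arithmetic concrete, hypothetical shapes for one (focused, depth) iteration, assuming batch size B, columns = {FORM, UPOS}, inSize = 64 and maxElemPerDepth[depth] = 3:

// context (after the wordEmbeddings lookup)  : (B, getInputSize(), 64)
// context.narrow(1, offset, 3*2)             : (B, 6, 64)
// .view({B, 3, (long)2*64})                  : (B, 3, 128), one timestep per depth element, fed to depthModules[depth]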
-std::size_t DepthLayerTreeEmbeddingModule::getOutputSize()
+std::size_t DepthLayerTreeEmbeddingModuleImpl::getOutputSize()
{
std::size_t outputSize = 0;
@@ -34,7 +75,7 @@ std::size_t DepthLayerTreeEmbeddingModule::getOutputSize()
return outputSize*(focusedBuffer.size()+focusedStack.size());
}
-std::size_t DepthLayerTreeEmbeddingModule::getInputSize()
+std::size_t DepthLayerTreeEmbeddingModuleImpl::getInputSize()
{
int inputSize = 0;
for (int maxElem : maxElemPerDepth)
@@ -42,7 +83,7 @@ std::size_t DepthLayerTreeEmbeddingModule::getInputSize()
return inputSize;
}
-void DepthLayerTreeEmbeddingModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
{
std::vector<long> focusedIndexes;
......
#include "FocusedColumnModule.hpp"
-FocusedColumnModule::FocusedColumnModule(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options) : focusedBuffer(focusedBuffer), focusedStack(focusedStack), column(column), maxNbElements(maxNbElements)
+FocusedColumnModuleImpl::FocusedColumnModuleImpl(const std::string & definition)
{
-myModule = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
+std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+{
+try
+{
+column = sm.str(1);
+maxNbElements = std::stoi(sm.str(2));
+for (auto & index : util::split(sm.str(3), ' '))
+focusedBuffer.emplace_back(std::stoi(index));
+for (auto & index : util::split(sm.str(4), ' '))
+focusedStack.emplace_back(std::stoi(index));
+auto subModuleType = sm.str(5);
+auto subModuleArguments = util::split(sm.str(6), ' ');
+auto options = MyModule::ModuleOptions(true)
+.bidirectional(std::stoi(subModuleArguments[0]))
+.num_layers(std::stoi(subModuleArguments[1]))
+.dropout(std::stof(subModuleArguments[2]))
+.complete(std::stoi(subModuleArguments[3]));
+int inSize = std::stoi(sm.str(7));
+int outSize = std::stoi(sm.str(8));
+wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
+if (subModuleType == "LSTM")
+myModule = register_module("myModule", LSTM(inSize, outSize, options));
+else if (subModuleType == "GRU")
+myModule = register_module("myModule", GRU(inSize, outSize, options));
+else
+util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
+} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+}))
+util::myThrow(fmt::format("invalid definition '{}'", definition));
+}
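As with the tree-embedding module, a definition accepted by this regex might look like the following (hypothetical values; Column names the input column to focus, NbElem the number of embeddings read per focused element):

Column{FORM} NbElem{10} Buffer{0 1 2} Stack{0 1} GRU{1 1 0.3 1} In{64} Out{128}

Through the TORCH_MODULE holder this would be constructed as FocusedColumnModule module(definition);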
-torch::Tensor FocusedColumnModule::forward(torch::Tensor input)
+torch::Tensor FocusedColumnModuleImpl::forward(torch::Tensor input)
{
std::vector<torch::Tensor> outputs;
for (unsigned int i = 0; i < focusedBuffer.size()+focusedStack.size(); i++)
-outputs.emplace_back(myModule->forward(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements)));
+outputs.emplace_back(myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements))));
return torch::cat(outputs, 1);
}
-std::size_t FocusedColumnModule::getOutputSize()
+std::size_t FocusedColumnModuleImpl::getOutputSize()
{
return (focusedBuffer.size()+focusedStack.size())*myModule->getOutputSize(maxNbElements);
}
-std::size_t FocusedColumnModule::getInputSize()
+std::size_t FocusedColumnModuleImpl::getInputSize()
{
return (focusedBuffer.size()+focusedStack.size()) * maxNbElements;
}
-void FocusedColumnModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
{
std::vector<long> focusedIndexes;
......
#include "GRUNetwork.hpp"
GRUNetworkImpl::GRUNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextGRUSize, int focusedGRUSize, int rawInputGRUSize, int splitTransGRUSize, int numLayers, bool bigru, float gruDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d)
{
// MyModule::ModuleOptions gruOptions{true,bigru,numLayers,gruDropout,false};
// auto gruOptionsAll = gruOptions;
// std::get<4>(gruOptionsAll) = true;
//
// int currentOutputSize = embeddingsSize;
// int currentInputSize = 1;
//
// contextGRU = register_module("contextGRU", ContextModule(columns, embeddingsSize, contextGRUSize, bufferContext, stackContext, gruOptions, unknownValueThreshold));
// contextGRU->setFirstInputIndex(currentInputSize);
// currentOutputSize += contextGRU->getOutputSize();
// currentInputSize += contextGRU->getInputSize();
//
// if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
// {
// rawInputGRU = register_module("rawInputGRU", RawInputModule(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputGRUSize, gruOptionsAll));
// rawInputGRU->setFirstInputIndex(currentInputSize);
// currentOutputSize += rawInputGRU->getOutputSize();
// currentInputSize += rawInputGRU->getInputSize();
// }
//
// if (!treeEmbeddingColumns.empty())
// {
// treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbeddingModule(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,gruOptions));
// treeEmbedding->setFirstInputIndex(currentInputSize);
// currentOutputSize += treeEmbedding->getOutputSize();
// currentInputSize += treeEmbedding->getInputSize();
// }
//
// splitTransGRU = register_module("splitTransGRU", SplitTransModule(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransGRUSize, gruOptionsAll));
// splitTransGRU->setFirstInputIndex(currentInputSize);
// currentOutputSize += splitTransGRU->getOutputSize();
// currentInputSize += splitTransGRU->getInputSize();
//
// for (unsigned int i = 0; i < focusedColumns.size(); i++)
// {
// focusedLstms.emplace_back(register_module(fmt::format("GRU_{}", focusedColumns[i]), FocusedColumnModule(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedGRUSize, gruOptions)));
// focusedLstms.back()->setFirstInputIndex(currentInputSize);
// currentOutputSize += focusedLstms.back()->getOutputSize();
// currentInputSize += focusedLstms.back()->getInputSize();
// }
//
// wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
// if (drop2d)
// embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue));
// else
// embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue));
// inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout));
//
// mlp = register_module("mlp", MLP(currentOutputSize, mlpParams));
//
// for (auto & it : nbOutputsPerState)
// outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
}
torch::Tensor GRUNetworkImpl::forward(torch::Tensor input)
{
return input;
// if (input.dim() == 1)
// input = input.unsqueeze(0);
//
// auto embeddings = wordEmbeddings(input);
// if (embeddingsDropout2d.is_empty())
// embeddings = embeddingsDropout(embeddings);
// else
// embeddings = embeddingsDropout2d(embeddings);
//
// std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)};
//
// outputs.emplace_back(contextGRU(embeddings));
//
// if (!rawInputGRU.is_empty())
// outputs.emplace_back(rawInputGRU(embeddings));
//
// if (!treeEmbedding.is_empty())
// outputs.emplace_back(treeEmbedding(embeddings));
//
// outputs.emplace_back(splitTransGRU(embeddings));
//
// for (auto & gru : focusedLstms)
// outputs.emplace_back(gru(embeddings));
//
// auto totalInput = inputDropout(torch::cat(outputs, 1));
//
// return outputLayersPerState.at(getState())(mlp(totalInput));
}
std::vector<std::vector<long>> GRUNetworkImpl::extractContext(Config & config, Dict & dict) const
{
std::vector<std::vector<long>> context;
return context;
// if (dict.size() >= maxNbEmbeddings)
// util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
//
// context.emplace_back();
//
// context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
//
// contextGRU->addToContext(context, dict, config, mustSplitUnknown());
//
// if (!rawInputGRU.is_empty())
// rawInputGRU->addToContext(context, dict, config, mustSplitUnknown());
//
// if (!treeEmbedding.is_empty())
// treeEmbedding->addToContext(context, dict, config, mustSplitUnknown());
//
// splitTransGRU->addToContext(context, dict, config, mustSplitUnknown());
//
// for (auto & gru : focusedLstms)
// gru->addToContext(context, dict, config, mustSplitUnknown());
//
// if (!mustSplitUnknown() && context.size() > 1)
// util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size()));
//
// return context;
}
#include "LSTMNetwork.hpp"
LSTMNetworkImpl::LSTMNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d)
{
// MyModule::ModuleOptions moduleOptions{true,bilstm,numLayers,lstmDropout,false};
// auto moduleOptionsAll = moduleOptions;
// std::get<4>(moduleOptionsAll) = true;
//
// int currentOutputSize = embeddingsSize;
// int currentInputSize = 1;
//
// contextLSTM = register_module("contextLSTM", ContextModule(columns, embeddingsSize, contextLSTMSize, bufferContext, stackContext, moduleOptions, unknownValueThreshold));
// contextLSTM->setFirstInputIndex(currentInputSize);
// currentOutputSize += contextLSTM->getOutputSize();
// currentInputSize += contextLSTM->getInputSize();
//
// if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
// {
// rawInputLSTM = register_module("rawInputLSTM", RawInputModule(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, moduleOptionsAll));
// rawInputLSTM->setFirstInputIndex(currentInputSize);
// currentOutputSize += rawInputLSTM->getOutputSize();
// currentInputSize += rawInputLSTM->getInputSize();
// }
//
// if (!treeEmbeddingColumns.empty())
// {
// treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbeddingModule(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,moduleOptions));
// treeEmbedding->setFirstInputIndex(currentInputSize);
// currentOutputSize += treeEmbedding->getOutputSize();
// currentInputSize += treeEmbedding->getInputSize();
// }
//
// splitTransLSTM = register_module("splitTransLSTM", SplitTransModule(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransLSTMSize, moduleOptionsAll));
// splitTransLSTM->setFirstInputIndex(currentInputSize);
// currentOutputSize += splitTransLSTM->getOutputSize();
// currentInputSize += splitTransLSTM->getInputSize();
//
// for (unsigned int i = 0; i < focusedColumns.size(); i++)
// {
// focusedLstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), FocusedColumnModule(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedLSTMSize, moduleOptions)));
// focusedLstms.back()->setFirstInputIndex(currentInputSize);
// currentOutputSize += focusedLstms.back()->getOutputSize();
// currentInputSize += focusedLstms.back()->getInputSize();
// }
//
// wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
// if (drop2d)
// embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue));
// else
// embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue));
// inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout));
//
// mlp = register_module("mlp", MLP(currentOutputSize, mlpParams));
//
// for (auto & it : nbOutputsPerState)
// outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
}
torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
{
return input;
// if (input.dim() == 1)
// input = input.unsqueeze(0);
//
// auto embeddings = wordEmbeddings(input);
// if (embeddingsDropout2d.is_empty())
// embeddings = embeddingsDropout(embeddings);
// else
// embeddings = embeddingsDropout2d(embeddings);