Commit b4228d7b authored by Franck Dary's avatar Franck Dary
Browse files

First draft of modular neural network

parent 5a2ea279
......@@ -19,6 +19,8 @@ class Classifier
void initNeuralNetwork(const std::vector<std::string> & definition);
void initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
void initGRU(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
public :
......
#include "Classifier.hpp"
#include "util.hpp"
#include "LSTMNetwork.hpp"
#include "GRUNetwork.hpp"
#include "RandomNetwork.hpp"
#include "ModularNetwork.hpp"
Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition)
{
......@@ -87,6 +89,10 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
this->nn.reset(new RandomNetworkImpl(nbOutputsPerState));
  else if (networkType == "LSTM")
    initLSTM(definition, curIndex, nbOutputsPerState);
  else if (networkType == "GRU")
    initGRU(definition, curIndex, nbOutputsPerState);
  else if (networkType == "Modular")
    initModular(definition, curIndex, nbOutputsPerState);
  else
    // Keep this list in sync with the dispatch branches above.
    util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM, GRU, Modular'", networkType));
......@@ -349,3 +355,241 @@ void Classifier::setState(const std::string & state)
nn->setState(state);
}
void Classifier::initGRU(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
{
int unknownValueThreshold;
std::vector<int> bufferContext, stackContext;
std::vector<std::string> columns, focusedColumns, treeEmbeddingColumns;
std::vector<int> focusedBuffer, focusedStack;
std::vector<int> treeEmbeddingBuffer, treeEmbeddingStack;
std::vector<int> maxNbElements;
std::vector<int> treeEmbeddingNbElems;
std::vector<std::pair<int, float>> mlp;
int rawInputLeftWindow, rawInputRightWindow;
int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, treeEmbeddingSize;
bool bilstm, drop2d;
float lstmDropout, embeddingsDropout, totalInputDropout;
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Unknown value threshold :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&unknownValueThreshold](auto sm)
{
unknownValueThreshold = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Unknown value threshold :) unknownValueThreshold"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Buffer context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&bufferContext](auto sm)
{
for (auto & index : util::split(sm.str(1), ' '))
bufferContext.emplace_back(std::stoi(index));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Buffer context :) {index1 index2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Stack context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&stackContext](auto sm)
{
for (auto & index : util::split(sm.str(1), ' '))
stackContext.emplace_back(std::stoi(index));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Stack context :) {index1 index2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&columns](auto sm)
{
columns = util::split(sm.str(1), ' ');
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Columns :) {index1 index2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedBuffer](auto sm)
{
for (auto & index : util::split(sm.str(1), ' '))
focusedBuffer.emplace_back(std::stoi(index));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused buffer :) {index1 index2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedStack](auto sm)
{
for (auto & index : util::split(sm.str(1), ' '))
focusedStack.emplace_back(std::stoi(index));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused stack :) {index1 index2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedColumns](auto sm)
{
focusedColumns = util::split(sm.str(1), ' ');
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused columns :) {index1 index2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Max nb elements :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&maxNbElements](auto sm)
{
for (auto & index : util::split(sm.str(1), ' '))
maxNbElements.emplace_back(std::stoi(index));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Max nb elements :) {size1 size2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input left window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLeftWindow](auto sm)
{
rawInputLeftWindow = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input left window :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input right window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputRightWindow](auto sm)
{
rawInputRightWindow = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input right window :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsSize](auto sm)
{
embeddingsSize = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings size :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:MLP :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&mlp](auto sm)
{
auto params = util::split(sm.str(1), ' ');
if (params.size() % 2)
util::myThrow("MLP must have even number of parameters");
for (unsigned int i = 0; i < params.size()/2; i++)
mlp.emplace_back(std::make_pair(std::stoi(params[2*i]), std::stof(params[2*i+1])));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(MLP :) {hidden1 dropout1 hidden2 dropout2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Context LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&contextLSTMSize](auto sm)
{
contextLSTMSize = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Context LSTM size :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&focusedLSTMSize](auto sm)
{
focusedLSTMSize = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused LSTM size :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Rawinput LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLSTMSize](auto sm)
{
rawInputLSTMSize = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw LSTM size :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Split trans LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&splitTransLSTMSize](auto sm)
{
splitTransLSTMSize = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Split trans LSTM size :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Num layers :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&nbLayers](auto sm)
{
nbLayers = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Num layers :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:BiLSTM :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&bilstm](auto sm)
{
bilstm = sm.str(1) == "true";
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(BiLSTM :) true|false"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LSTM dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&lstmDropout](auto sm)
{
lstmDropout = std::stof(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(LSTM dropout :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Total input dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&totalInputDropout](auto sm)
{
totalInputDropout = std::stof(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Total input dropout :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsDropout](auto sm)
{
embeddingsDropout = std::stof(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings dropout :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Dropout 2d :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&drop2d](auto sm)
{
drop2d = sm.str(1) == "true";
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Dropout 2d :) true|false"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingColumns](auto sm)
{
treeEmbeddingColumns = util::split(sm.str(1), ' ');
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding columns :) {column1 column2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingBuffer](auto sm)
{
for (auto & index : util::split(sm.str(1), ' '))
treeEmbeddingBuffer.emplace_back(std::stoi(index));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding buffer :) {index1 index2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingStack](auto sm)
{
for (auto & index : util::split(sm.str(1), ' '))
treeEmbeddingStack.emplace_back(std::stoi(index));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding stack :) {index1 index2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding nb :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingNbElems](auto sm)
{
for (auto & index : util::split(sm.str(1), ' '))
treeEmbeddingNbElems.emplace_back(std::stoi(index));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding nb :) {size1 size2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&treeEmbeddingSize](auto sm)
{
treeEmbeddingSize = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding size :) value"));
this->nn.reset(new GRUNetworkImpl(nbOutputsPerState, unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d));
}
void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
{
  // Collect every definition line up to (and consuming) the "End" marker;
  // each collected line describes one module of the ModularNetwork.
  const std::string blanks = "(?:(?:\\s|\\t)*)";
  const std::regex endMarker(fmt::format("{}End{}", blanks, blanks));
  std::vector<std::string> moduleLines;
  while (curIndex < definition.size())
  {
    const bool reachedEnd = util::doIfNameMatch(endMarker, definition[curIndex], [](auto){});
    curIndex++;
    if (reachedEnd)
      break;
    moduleLines.emplace_back(definition[curIndex-1]);
  }
  this->nn.reset(new ModularNetworkImpl(nbOutputsPerState, moduleLines));
}
#ifndef CONTEXTLSTM__H
#define CONTEXTLSTM__H
#ifndef CONTEXTMODULE__H
#define CONTEXTMODULE__H
#include <torch/torch.h>
#include "Submodule.hpp"
#include "MyModule.hpp"
#include "GRU.hpp"
#include "LSTM.hpp"
class ContextLSTMImpl : public torch::nn::Module, public Submodule
class ContextModuleImpl : public Submodule
{
private :
LSTM lstm{nullptr};
torch::nn::Embedding wordEmbeddings{nullptr};
std::shared_ptr<MyModule> myModule{nullptr};
std::vector<std::string> columns;
std::vector<int> bufferContext;
std::vector<int> stackContext;
......@@ -18,13 +21,13 @@ class ContextLSTMImpl : public torch::nn::Module, public Submodule
public :
ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold);
ContextModuleImpl(const std::string & definition);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
};
TORCH_MODULE(ContextLSTM);
TORCH_MODULE(ContextModule);
#endif
......@@ -3,9 +3,11 @@
#include <torch/torch.h>
#include "Submodule.hpp"
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule
class DepthLayerTreeEmbeddingModule : public Submodule
{
private :
......@@ -13,17 +15,16 @@ class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule
std::vector<std::string> columns;
std::vector<int> focusedBuffer;
std::vector<int> focusedStack;
std::vector<LSTM> depthLstm;
std::vector<std::shared_ptr<MyModule>> depthModules;
public :
DepthLayerTreeEmbeddingImpl(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, LSTMImpl::LSTMOptions options);
DepthLayerTreeEmbeddingModule(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, MyModule::ModuleOptions options);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
};
TORCH_MODULE(DepthLayerTreeEmbedding);
#endif
#ifndef FOCUSEDCOLUMNLSTM__H
#define FOCUSEDCOLUMNLSTM__H
#ifndef FOCUSEDCOLUMNMODULE__H
#define FOCUSEDCOLUMNMODULE__H
#include <torch/torch.h>
#include "Submodule.hpp"
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
class FocusedColumnLSTMImpl : public torch::nn::Module, public Submodule
class FocusedColumnModule : public Submodule
{
private :
LSTM lstm{nullptr};
std::shared_ptr<MyModule> myModule{nullptr};
std::vector<int> focusedBuffer, focusedStack;
std::string column;
int maxNbElements;
public :
FocusedColumnLSTMImpl(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options);
FocusedColumnModule(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
};
TORCH_MODULE(FocusedColumnLSTM);
#endif
#ifndef GRU__H
#define GRU__H
#include <torch/torch.h>
#include "MyModule.hpp"
// Recurrent submodule backed by torch::nn::GRU, implementing the generic
// MyModule interface so networks can pick GRU or LSTM interchangeably.
class GRUImpl : public MyModule
{
private :
// The wrapped libtorch GRU layer (configured in the constructor).
torch::nn::GRU gru{nullptr};
// NOTE(review): presumably selects whether forward() returns all timesteps or
// only the last one — confirm against the GRU.cpp implementation.
bool outputAll;
public :
// inputSize/outputSize are per-timestep feature sizes; 'options' carries the
// shared recurrent-module settings (see MyModule::ModuleOptions).
GRUImpl(int inputSize, int outputSize, ModuleOptions options);
torch::Tensor forward(torch::Tensor input);
// Number of output features produced for an input sequence of 'sequenceLength'.
int getOutputSize(int sequenceLength);
};
TORCH_MODULE(GRU);
#endif
#ifndef GRUNETWORK__H
#define GRUNETWORK__H
#include "NeuralNetwork.hpp"
#include "ContextModule.hpp"
#include "RawInputModule.hpp"
#include "SplitTransModule.hpp"
#include "FocusedColumnModule.hpp"
#include "DepthLayerTreeEmbeddingModule.hpp"
#include "MLP.hpp"
// GRU-based transition classifier network. In this first draft of the modular
// refactoring all member fields are commented out; only the NeuralNetworkImpl
// interface (forward / extractContext) plus the constructor remain declared.
class GRUNetworkImpl : public NeuralNetworkImpl
{
// private :
//
// torch::nn::Embedding wordEmbeddings{nullptr};
// torch::nn::Dropout2d embeddingsDropout2d{nullptr};
// torch::nn::Dropout embeddingsDropout{nullptr};
// torch::nn::Dropout inputDropout{nullptr};
//
// MLP mlp{nullptr};
// ContextModule contextGRU{nullptr};
// RawInputModule rawInputGRU{nullptr};
// SplitTransModule splitTransGRU{nullptr};
// DepthLayerTreeEmbeddingModule treeEmbedding{nullptr};
// std::vector<FocusedColumnModule> focusedLstms;
// std::map<std::string,torch::nn::Linear> outputLayersPerState;
public :
// Builds the network from the hyperparameters parsed by Classifier::initGRU;
// parameters mirror that parser's settings in the same order.
GRUNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextGRUSize, int focusedGRUSize, int rawInputGRUSize, int splitTransGRUSize, int numLayers, bool bigru, float gruDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d);
// Forward pass over a batch of extracted contexts.
torch::Tensor forward(torch::Tensor input) override;
// Builds, per configuration, the index rows later fed to forward()
// (dictionary indices looked up through 'dict' — see NeuralNetworkImpl).
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};
#endif
......@@ -2,14 +2,10 @@
#define LSTM__H
#include <torch/torch.h>
#include "fmt/core.h"
#include "MyModule.hpp"
class LSTMImpl : public torch::nn::Module
class LSTMImpl : public MyModule
{
public :
using LSTMOptions = std::tuple<bool,bool,int,float,bool>;
private :
torch::nn::LSTM lstm{nullptr};
......@@ -17,7 +13,7 @@ class LSTMImpl : public torch::nn::Module
public :
LSTMImpl(int inputSize, int outputSize, LSTMOptions options);
LSTMImpl(int inputSize, int outputSize, ModuleOptions options);
torch::Tensor forward(torch::Tensor input);
int getOutputSize(int sequenceLength);
};
......
......@@ -2,29 +2,29 @@
#define LSTMNETWORK__H
#include "NeuralNetwork.hpp"
#include "ContextLSTM.hpp"
#include "RawInputLSTM.hpp"
#include "SplitTransLSTM.hpp"
#include "FocusedColumnLSTM.hpp"
#include "ContextModule.hpp"
#include "RawInputModule.hpp"
#include "SplitTransModule.hpp"
#include "FocusedColumnModule.hpp"
#include "DepthLayerTreeEmbeddingModule.hpp"
#include "MLP.hpp"
#include "DepthLayerTreeEmbedding.hpp"
class LSTMNetworkImpl : public NeuralNetworkImpl
{
private :
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Dropout2d embeddingsDropout2d{nullptr};
torch::nn::Dropout embeddingsDropout{nullptr};
torch::nn::Dropout inputDropout{nullptr};
MLP mlp{nullptr};
ContextLSTM contextLSTM{nullptr};
RawInputLSTM rawInputLSTM{nullptr};
SplitTransLSTM splitTransLSTM{nullptr};
DepthLayerTreeEmbedding treeEmbedding{nullptr};
std::vector<FocusedColumnLSTM> focusedLstms;
std::map<std::string,torch::nn::Linear> outputLayersPerState;
// private :
//
// torch::nn::Embedding wordEmbeddings{nullptr};
// torch::nn::Dropout2d embeddingsDropout2d{nullptr};
// torch::nn::Dropout embeddingsDropout{nullptr};
// torch::nn::Dropout inputDropout{nullptr};
//
// MLP mlp{nullptr};
// ContextModule contextLSTM{nullptr};
// RawInputModule rawInputLSTM{nullptr};
// SplitTransModule splitTransLSTM{nullptr};
// DepthLayerTreeEmbeddingModule treeEmbedding{nullptr};
// std::vector<FocusedColumnModule> focusedLstms;
// std::map<std::string,torch::nn::Linear> outputLayersPerState;
public :
......
......@@ -13,7 +13,7 @@ class MLPImpl : public torch::nn::Module
public :
MLPImpl(int inputSize, std::vector<std::pair<int, float>> params);
MLPImpl(int inputSize, std::string definition);
torch::Tensor forward(torch::Tensor input);
std::size_t outputSize() const;
};
......
#ifndef MODULARNETWORK__H
#define MODULARNETWORK__H
#include "NeuralNetwork.hpp"
#include "ContextModule.hpp"
#include "RawInputModule.hpp"
#include "SplitTransModule.hpp"
#include "FocusedColumnModule.hpp"
#include "DepthLayerTreeEmbeddingModule.hpp"
#include "MLP.hpp"
// Network whose architecture is described at runtime: each string in
// 'definitions' (the lines gathered by Classifier::initModular) configures one
// Submodule; their outputs go through an MLP and one Linear layer per state.
class ModularNetworkImpl : public NeuralNetworkImpl
{
private :
// NOTE(review): presumably the embedding table shared by the submodules —
// confirm against the constructor implementation.
torch::nn::Embedding wordEmbeddings{nullptr};
// Dropouts applied to embeddings (2d = channel-wise) and to the whole input.
torch::nn::Dropout2d embeddingsDropout2d{nullptr};
torch::nn::Dropout embeddingsDropout{nullptr};
torch::nn::Dropout inputDropout{nullptr};
// Final feed-forward classifier over the concatenated submodule outputs.
MLP mlp{nullptr};
// One Submodule per definition line passed to the constructor.
std::vector<std::shared_ptr<Submodule>> modules;
// One output Linear layer per classifier state name.
std::map<std::string,torch::nn::Linear> outputLayersPerState;
public :
// 'definitions': one textual module description per submodule to instantiate.
ModularNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};