Skip to content
Snippets Groups Projects
Commit b4228d7b authored by Franck Dary's avatar Franck Dary
Browse files

First draft of modular neural network

parent 5a2ea279
Branches
No related tags found
No related merge requests found
Showing
with 693 additions and 74 deletions
...@@ -19,6 +19,8 @@ class Classifier ...@@ -19,6 +19,8 @@ class Classifier
void initNeuralNetwork(const std::vector<std::string> & definition); void initNeuralNetwork(const std::vector<std::string> & definition);
void initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState); void initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
void initGRU(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
public : public :
......
#include "Classifier.hpp" #include "Classifier.hpp"
#include "util.hpp" #include "util.hpp"
#include "LSTMNetwork.hpp" #include "LSTMNetwork.hpp"
#include "GRUNetwork.hpp"
#include "RandomNetwork.hpp" #include "RandomNetwork.hpp"
#include "ModularNetwork.hpp"
Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition) Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition)
{ {
...@@ -87,6 +89,10 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) ...@@ -87,6 +89,10 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
this->nn.reset(new RandomNetworkImpl(nbOutputsPerState)); this->nn.reset(new RandomNetworkImpl(nbOutputsPerState));
else if (networkType == "LSTM") else if (networkType == "LSTM")
initLSTM(definition, curIndex, nbOutputsPerState); initLSTM(definition, curIndex, nbOutputsPerState);
else if (networkType == "GRU")
initGRU(definition, curIndex, nbOutputsPerState);
else if (networkType == "Modular")
initModular(definition, curIndex, nbOutputsPerState);
else else
util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM'", networkType)); util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM'", networkType));
...@@ -349,3 +355,241 @@ void Classifier::setState(const std::string & state) ...@@ -349,3 +355,241 @@ void Classifier::setState(const std::string & state)
nn->setState(state); nn->setState(state);
} }
void Classifier::initGRU(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
{
int unknownValueThreshold;
std::vector<int> bufferContext, stackContext;
std::vector<std::string> columns, focusedColumns, treeEmbeddingColumns;
std::vector<int> focusedBuffer, focusedStack;
std::vector<int> treeEmbeddingBuffer, treeEmbeddingStack;
std::vector<int> maxNbElements;
std::vector<int> treeEmbeddingNbElems;
std::vector<std::pair<int, float>> mlp;
int rawInputLeftWindow, rawInputRightWindow;
int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, treeEmbeddingSize;
bool bilstm, drop2d;
float lstmDropout, embeddingsDropout, totalInputDropout;
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Unknown value threshold :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&unknownValueThreshold](auto sm)
{
unknownValueThreshold = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Unknown value threshold :) unknownValueThreshold"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Buffer context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&bufferContext](auto sm)
{
for (auto & index : util::split(sm.str(1), ' '))
bufferContext.emplace_back(std::stoi(index));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Buffer context :) {index1 index2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Stack context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&stackContext](auto sm)
{
for (auto & index : util::split(sm.str(1), ' '))
stackContext.emplace_back(std::stoi(index));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Stack context :) {index1 index2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&columns](auto sm)
{
columns = util::split(sm.str(1), ' ');
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Columns :) {index1 index2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedBuffer](auto sm)
{
for (auto & index : util::split(sm.str(1), ' '))
focusedBuffer.emplace_back(std::stoi(index));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused buffer :) {index1 index2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedStack](auto sm)
{
for (auto & index : util::split(sm.str(1), ' '))
focusedStack.emplace_back(std::stoi(index));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused stack :) {index1 index2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedColumns](auto sm)
{
focusedColumns = util::split(sm.str(1), ' ');
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused columns :) {index1 index2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Max nb elements :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&maxNbElements](auto sm)
{
for (auto & index : util::split(sm.str(1), ' '))
maxNbElements.emplace_back(std::stoi(index));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Max nb elements :) {size1 size2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input left window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLeftWindow](auto sm)
{
rawInputLeftWindow = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input left window :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input right window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputRightWindow](auto sm)
{
rawInputRightWindow = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input right window :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsSize](auto sm)
{
embeddingsSize = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings size :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:MLP :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&mlp](auto sm)
{
auto params = util::split(sm.str(1), ' ');
if (params.size() % 2)
util::myThrow("MLP must have even number of parameters");
for (unsigned int i = 0; i < params.size()/2; i++)
mlp.emplace_back(std::make_pair(std::stoi(params[2*i]), std::stof(params[2*i+1])));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(MLP :) {hidden1 dropout1 hidden2 dropout2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Context LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&contextLSTMSize](auto sm)
{
contextLSTMSize = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Context LSTM size :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&focusedLSTMSize](auto sm)
{
focusedLSTMSize = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused LSTM size :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Rawinput LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLSTMSize](auto sm)
{
rawInputLSTMSize = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw LSTM size :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Split trans LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&splitTransLSTMSize](auto sm)
{
splitTransLSTMSize = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Split trans LSTM size :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Num layers :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&nbLayers](auto sm)
{
nbLayers = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Num layers :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:BiLSTM :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&bilstm](auto sm)
{
bilstm = sm.str(1) == "true";
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(BiLSTM :) true|false"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LSTM dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&lstmDropout](auto sm)
{
lstmDropout = std::stof(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(LSTM dropout :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Total input dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&totalInputDropout](auto sm)
{
totalInputDropout = std::stof(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Total input dropout :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsDropout](auto sm)
{
embeddingsDropout = std::stof(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings dropout :) value"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Dropout 2d :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&drop2d](auto sm)
{
drop2d = sm.str(1) == "true";
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Dropout 2d :) true|false"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingColumns](auto sm)
{
treeEmbeddingColumns = util::split(sm.str(1), ' ');
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding columns :) {column1 column2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingBuffer](auto sm)
{
for (auto & index : util::split(sm.str(1), ' '))
treeEmbeddingBuffer.emplace_back(std::stoi(index));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding buffer :) {index1 index2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingStack](auto sm)
{
for (auto & index : util::split(sm.str(1), ' '))
treeEmbeddingStack.emplace_back(std::stoi(index));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding stack :) {index1 index2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding nb :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingNbElems](auto sm)
{
for (auto & index : util::split(sm.str(1), ' '))
treeEmbeddingNbElems.emplace_back(std::stoi(index));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding nb :) {size1 size2...}"));
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&treeEmbeddingSize](auto sm)
{
treeEmbeddingSize = std::stoi(sm.str(1));
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding size :) value"));
this->nn.reset(new GRUNetworkImpl(nbOutputsPerState, unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d));
}
void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
{
  // Collect every definition line up to (and consuming) the "End" marker;
  // each collected line describes one submodule of the ModularNetwork.
  std::string anyBlanks = "(?:(?:\\s|\\t)*)";
  std::regex endRegex(fmt::format("{}End{}",anyBlanks,anyBlanks));
  std::vector<std::string> modulesDefinitions;

  while (curIndex < definition.size())
  {
    const auto & line = definition[curIndex];
    curIndex++;
    if (util::doIfNameMatch(endRegex, line, [](auto){}))
      break;
    modulesDefinitions.emplace_back(line);
  }

  this->nn.reset(new ModularNetworkImpl(nbOutputsPerState, modulesDefinitions));
}
#ifndef CONTEXTLSTM__H #ifndef CONTEXTMODULE__H
#define CONTEXTLSTM__H #define CONTEXTMODULE__H
#include <torch/torch.h> #include <torch/torch.h>
#include "Submodule.hpp" #include "Submodule.hpp"
#include "MyModule.hpp"
#include "GRU.hpp"
#include "LSTM.hpp" #include "LSTM.hpp"
class ContextLSTMImpl : public torch::nn::Module, public Submodule class ContextModuleImpl : public Submodule
{ {
private : private :
LSTM lstm{nullptr}; torch::nn::Embedding wordEmbeddings{nullptr};
std::shared_ptr<MyModule> myModule{nullptr};
std::vector<std::string> columns; std::vector<std::string> columns;
std::vector<int> bufferContext; std::vector<int> bufferContext;
std::vector<int> stackContext; std::vector<int> stackContext;
...@@ -18,13 +21,13 @@ class ContextLSTMImpl : public torch::nn::Module, public Submodule ...@@ -18,13 +21,13 @@ class ContextLSTMImpl : public torch::nn::Module, public Submodule
public : public :
ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold); ContextModuleImpl(const std::string & definition);
torch::Tensor forward(torch::Tensor input); torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
}; };
TORCH_MODULE(ContextLSTM); TORCH_MODULE(ContextModule);
#endif #endif
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
#include <torch/torch.h> #include <torch/torch.h>
#include "Submodule.hpp" #include "Submodule.hpp"
#include "MyModule.hpp"
#include "LSTM.hpp" #include "LSTM.hpp"
#include "GRU.hpp"
class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule class DepthLayerTreeEmbeddingModule : public Submodule
{ {
private : private :
...@@ -13,17 +15,16 @@ class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule ...@@ -13,17 +15,16 @@ class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule
std::vector<std::string> columns; std::vector<std::string> columns;
std::vector<int> focusedBuffer; std::vector<int> focusedBuffer;
std::vector<int> focusedStack; std::vector<int> focusedStack;
std::vector<LSTM> depthLstm; std::vector<std::shared_ptr<MyModule>> depthModules;
public : public :
DepthLayerTreeEmbeddingImpl(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, LSTMImpl::LSTMOptions options); DepthLayerTreeEmbeddingModule(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, MyModule::ModuleOptions options);
torch::Tensor forward(torch::Tensor input); torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
}; };
TORCH_MODULE(DepthLayerTreeEmbedding);
#endif #endif
#ifndef FOCUSEDCOLUMNLSTM__H #ifndef FOCUSEDCOLUMNMODULE__H
#define FOCUSEDCOLUMNLSTM__H #define FOCUSEDCOLUMNMODULE__H
#include <torch/torch.h> #include <torch/torch.h>
#include "Submodule.hpp" #include "Submodule.hpp"
#include "MyModule.hpp"
#include "LSTM.hpp" #include "LSTM.hpp"
#include "GRU.hpp"
class FocusedColumnLSTMImpl : public torch::nn::Module, public Submodule class FocusedColumnModule : public Submodule
{ {
private : private :
LSTM lstm{nullptr}; std::shared_ptr<MyModule> myModule{nullptr};
std::vector<int> focusedBuffer, focusedStack; std::vector<int> focusedBuffer, focusedStack;
std::string column; std::string column;
int maxNbElements; int maxNbElements;
public : public :
FocusedColumnLSTMImpl(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options); FocusedColumnModule(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options);
torch::Tensor forward(torch::Tensor input); torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
}; };
TORCH_MODULE(FocusedColumnLSTM);
#endif #endif
#ifndef GRU__H
#define GRU__H

#include <torch/torch.h>
#include "MyModule.hpp"

// GRU-based implementation of the MyModule interface, wrapping torch::nn::GRU.
class GRUImpl : public MyModule
{
  private :

  torch::nn::GRU gru{nullptr};
  // Presumably selects whether forward() returns all timesteps or only the
  // last one — confirm against the .cpp implementation.
  bool outputAll;

  public :

  GRUImpl(int inputSize, int outputSize, ModuleOptions options);
  // 'override' added: these implement the pure virtuals declared in MyModule.
  torch::Tensor forward(torch::Tensor input) override;
  int getOutputSize(int sequenceLength) override;
};
TORCH_MODULE(GRU);

#endif
#ifndef GRUNETWORK__H
#define GRUNETWORK__H

#include "NeuralNetwork.hpp"
#include "ContextModule.hpp"
#include "RawInputModule.hpp"
#include "SplitTransModule.hpp"
#include "FocusedColumnModule.hpp"
#include "DepthLayerTreeEmbeddingModule.hpp"
#include "MLP.hpp"

// GRU-based network declaration ("first draft of modular neural network"
// commit).  The member list below is intentionally commented out: only the
// constructor / forward / extractContext interface is declared here.
class GRUNetworkImpl : public NeuralNetworkImpl
{
// private :
//
// torch::nn::Embedding wordEmbeddings{nullptr};
// torch::nn::Dropout2d embeddingsDropout2d{nullptr};
// torch::nn::Dropout embeddingsDropout{nullptr};
// torch::nn::Dropout inputDropout{nullptr};
//
// MLP mlp{nullptr};
// ContextModule contextGRU{nullptr};
// RawInputModule rawInputGRU{nullptr};
// SplitTransModule splitTransGRU{nullptr};
// DepthLayerTreeEmbeddingModule treeEmbedding{nullptr};
// std::vector<FocusedColumnModule> focusedLstms;
// std::map<std::string,torch::nn::Linear> outputLayersPerState;

  public :

  // Parameters mirror, in the same order, the fields parsed by
  // Classifier::initGRU from the textual network definition.
  GRUNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextGRUSize, int focusedGRUSize, int rawInputGRUSize, int splitTransGRUSize, int numLayers, bool bigru, float gruDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d);
  torch::Tensor forward(torch::Tensor input) override;
  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};

#endif
...@@ -2,14 +2,10 @@ ...@@ -2,14 +2,10 @@
#define LSTM__H #define LSTM__H
#include <torch/torch.h> #include <torch/torch.h>
#include "fmt/core.h" #include "MyModule.hpp"
class LSTMImpl : public torch::nn::Module class LSTMImpl : public MyModule
{ {
public :
using LSTMOptions = std::tuple<bool,bool,int,float,bool>;
private : private :
torch::nn::LSTM lstm{nullptr}; torch::nn::LSTM lstm{nullptr};
...@@ -17,7 +13,7 @@ class LSTMImpl : public torch::nn::Module ...@@ -17,7 +13,7 @@ class LSTMImpl : public torch::nn::Module
public : public :
LSTMImpl(int inputSize, int outputSize, LSTMOptions options); LSTMImpl(int inputSize, int outputSize, ModuleOptions options);
torch::Tensor forward(torch::Tensor input); torch::Tensor forward(torch::Tensor input);
int getOutputSize(int sequenceLength); int getOutputSize(int sequenceLength);
}; };
......
...@@ -2,29 +2,29 @@ ...@@ -2,29 +2,29 @@
#define LSTMNETWORK__H #define LSTMNETWORK__H
#include "NeuralNetwork.hpp" #include "NeuralNetwork.hpp"
#include "ContextLSTM.hpp" #include "ContextModule.hpp"
#include "RawInputLSTM.hpp" #include "RawInputModule.hpp"
#include "SplitTransLSTM.hpp" #include "SplitTransModule.hpp"
#include "FocusedColumnLSTM.hpp" #include "FocusedColumnModule.hpp"
#include "DepthLayerTreeEmbeddingModule.hpp"
#include "MLP.hpp" #include "MLP.hpp"
#include "DepthLayerTreeEmbedding.hpp"
class LSTMNetworkImpl : public NeuralNetworkImpl class LSTMNetworkImpl : public NeuralNetworkImpl
{ {
private : // private :
//
torch::nn::Embedding wordEmbeddings{nullptr}; // torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Dropout2d embeddingsDropout2d{nullptr}; // torch::nn::Dropout2d embeddingsDropout2d{nullptr};
torch::nn::Dropout embeddingsDropout{nullptr}; // torch::nn::Dropout embeddingsDropout{nullptr};
torch::nn::Dropout inputDropout{nullptr}; // torch::nn::Dropout inputDropout{nullptr};
//
MLP mlp{nullptr}; // MLP mlp{nullptr};
ContextLSTM contextLSTM{nullptr}; // ContextModule contextLSTM{nullptr};
RawInputLSTM rawInputLSTM{nullptr}; // RawInputModule rawInputLSTM{nullptr};
SplitTransLSTM splitTransLSTM{nullptr}; // SplitTransModule splitTransLSTM{nullptr};
DepthLayerTreeEmbedding treeEmbedding{nullptr}; // DepthLayerTreeEmbeddingModule treeEmbedding{nullptr};
std::vector<FocusedColumnLSTM> focusedLstms; // std::vector<FocusedColumnModule> focusedLstms;
std::map<std::string,torch::nn::Linear> outputLayersPerState; // std::map<std::string,torch::nn::Linear> outputLayersPerState;
public : public :
......
...@@ -13,7 +13,7 @@ class MLPImpl : public torch::nn::Module ...@@ -13,7 +13,7 @@ class MLPImpl : public torch::nn::Module
public : public :
MLPImpl(int inputSize, std::vector<std::pair<int, float>> params); MLPImpl(int inputSize, std::string definition);
torch::Tensor forward(torch::Tensor input); torch::Tensor forward(torch::Tensor input);
std::size_t outputSize() const; std::size_t outputSize() const;
}; };
......
#ifndef MODULARNETWORK__H
#define MODULARNETWORK__H

#include "NeuralNetwork.hpp"
#include "ContextModule.hpp"
#include "RawInputModule.hpp"
#include "SplitTransModule.hpp"
#include "FocusedColumnModule.hpp"
#include "DepthLayerTreeEmbeddingModule.hpp"
#include "MLP.hpp"

// Network assembled at runtime from a list of textual module definitions
// (collected by Classifier::initModular): each definition line is expected to
// become one Submodule held in 'modules'.
class ModularNetworkImpl : public NeuralNetworkImpl
{
  private :

  torch::nn::Embedding wordEmbeddings{nullptr};
  torch::nn::Dropout2d embeddingsDropout2d{nullptr};
  torch::nn::Dropout embeddingsDropout{nullptr};
  torch::nn::Dropout inputDropout{nullptr};

  // Final classification layer stack.
  MLP mlp{nullptr};
  // Submodules built from the definition lines, owned polymorphically.
  std::vector<std::shared_ptr<Submodule>> modules;
  // One output layer per parser state name — presumably sized from
  // nbOutputsPerState; confirm against the .cpp implementation.
  std::map<std::string,torch::nn::Linear> outputLayersPerState;

  public :

  ModularNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions);
  torch::Tensor forward(torch::Tensor input) override;
  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};

#endif
#ifndef MYMODULE__H
#define MYMODULE__H

#include <torch/torch.h>

// Common interface for the recurrent submodules (LSTM, GRU, ...) so that
// client code can hold any of them through a MyModule pointer.
class MyModule : public torch::nn::Module
{
  public :

  // Options shared by every concrete module, stored as the tuple
  // (batchFirst, bidirectional, numLayers, dropout, complete).
  // Builder-style setters allow chained construction; fields not set
  // keep their value-initialized defaults (false / 0 / 0.0f).
  struct ModuleOptions : std::tuple<bool,bool,int,float,bool>
  {
    ModuleOptions(bool batchFirst){std::get<0>(*this)=batchFirst;};
    ModuleOptions & bidirectional(bool val) {std::get<1>(*this)=val; return *this;}
    ModuleOptions & num_layers(int num) {std::get<2>(*this)=num; return *this;}
    ModuleOptions & dropout(float val) {std::get<3>(*this)=val; return *this;}
    ModuleOptions & complete(bool val) {std::get<4>(*this)=val; return *this;}
  };

  // Explicit virtual destructor: instances are deleted through base pointers.
  virtual ~MyModule() = default;

  // Size of the output produced by forward() for a sequence of this length.
  virtual int getOutputSize(int sequenceLength) = 0;
  virtual torch::Tensor forward(torch::Tensor) = 0;
};

#endif
#ifndef RAWINPUTLSTM__H #ifndef RAWINPUTMODULE__H
#define RAWINPUTLSTM__H #define RAWINPUTMODULE__H
#include <torch/torch.h> #include <torch/torch.h>
#include "Submodule.hpp" #include "Submodule.hpp"
#include "MyModule.hpp"
#include "LSTM.hpp" #include "LSTM.hpp"
#include "GRU.hpp"
class RawInputLSTMImpl : public torch::nn::Module, public Submodule class RawInputModule : public Submodule
{ {
private : private :
LSTM lstm{nullptr}; std::shared_ptr<MyModule> myModule{nullptr};
int leftWindow, rightWindow; int leftWindow, rightWindow;
public : public :
RawInputLSTMImpl(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options); RawInputModule(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options);
torch::Tensor forward(torch::Tensor input); torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
}; };
TORCH_MODULE(RawInputLSTM);
#endif #endif
#ifndef SPLITTRANSLSTM__H #ifndef SPLITTRANSMODULE__H
#define SPLITTRANSLSTM__H #define SPLITTRANSMODULE__H
#include <torch/torch.h> #include <torch/torch.h>
#include "Submodule.hpp" #include "Submodule.hpp"
#include "MyModule.hpp"
#include "LSTM.hpp" #include "LSTM.hpp"
#include "GRU.hpp"
class SplitTransLSTMImpl : public torch::nn::Module, public Submodule class SplitTransModule : public Submodule
{ {
private : private :
LSTM lstm{nullptr}; std::shared_ptr<MyModule> myModule{nullptr};
int maxNbTrans; int maxNbTrans;
public : public :
SplitTransLSTMImpl(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options); SplitTransModule(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options);
torch::Tensor forward(torch::Tensor input); torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
}; };
TORCH_MODULE(SplitTransLSTM);
#endif #endif
#ifndef SUBMODULE__H #ifndef SUBMODULE__H
#define SUBMODULE__H #define SUBMODULE__H
#include <torch/torch.h>
#include "Dict.hpp" #include "Dict.hpp"
#include "Config.hpp" #include "Config.hpp"
class Submodule class Submodule : public torch::nn::Module
{ {
protected : protected :
...@@ -16,6 +17,7 @@ class Submodule ...@@ -16,6 +17,7 @@ class Submodule
virtual std::size_t getOutputSize() = 0; virtual std::size_t getOutputSize() = 0;
virtual std::size_t getInputSize() = 0; virtual std::size_t getInputSize() = 0;
virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const = 0; virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const = 0;
virtual torch::Tensor forward(torch::Tensor input) = 0;
}; };
#endif #endif
......
#include "ContextLSTM.hpp" #include "ContextModule.hpp"
ContextLSTMImpl::ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold) : columns(columns), bufferContext(bufferContext), stackContext(stackContext), unknownValueThreshold(unknownValueThreshold) ContextModuleImpl::ContextModuleImpl(const std::string & definition)
{ {
lstm = register_module("lstm", LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options)); std::regex regex("(?:(?:\\s|\\t)*)Unk\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
{
try
{
unknownValueThreshold = std::stoi(sm.str(1));
for (auto & index : util::split(sm.str(2), ' '))
bufferContext.emplace_back(std::stoi(index));
for (auto & index : util::split(sm.str(3), ' '))
stackContext.emplace_back(std::stoi(index));
columns = util::split(sm.str(4), ' ');
auto subModuleType = sm.str(5);
auto subModuleArguments = util::split(sm.str(6), ' ');
auto options = MyModule::ModuleOptions(true)
.bidirectional(std::stoi(subModuleArguments[0]))
.num_layers(std::stoi(subModuleArguments[1]))
.dropout(std::stof(subModuleArguments[2]))
.complete(std::stoi(subModuleArguments[3]));
int inSize = std::stoi(sm.str(7));
int outSize = std::stoi(sm.str(8));
wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
if (subModuleType == "LSTM")
myModule = register_module("myModule", LSTM(columns.size()*inSize, outSize, options));
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
}))
util::myThrow(fmt::format("invalid definition '{}'", definition));
} }
std::size_t ContextLSTMImpl::getOutputSize() std::size_t ContextModuleImpl::getOutputSize()
{ {
return lstm->getOutputSize(bufferContext.size()+stackContext.size()); return myModule->getOutputSize(bufferContext.size()+stackContext.size());
} }
std::size_t ContextLSTMImpl::getInputSize() std::size_t ContextModuleImpl::getInputSize()
{ {
return columns.size()*(bufferContext.size()+stackContext.size()); return columns.size()*(bufferContext.size()+stackContext.size());
} }
void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const
{ {
std::vector<long> contextIndexes; std::vector<long> contextIndexes;
...@@ -49,12 +87,12 @@ void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dic ...@@ -49,12 +87,12 @@ void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dic
} }
} }
torch::Tensor ContextLSTMImpl::forward(torch::Tensor input) torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
{ {
auto context = input.narrow(1, firstInputIndex, getInputSize()); auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()));
context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*context.size(2)}); context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*context.size(2)});
return lstm(context); return myModule->forward(context);
} }
#include "DepthLayerTreeEmbedding.hpp" #include "DepthLayerTreeEmbeddingModule.hpp"
DepthLayerTreeEmbeddingImpl::DepthLayerTreeEmbeddingImpl(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, LSTMImpl::LSTMOptions options) : DepthLayerTreeEmbeddingModule::DepthLayerTreeEmbeddingModule(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, MyModule::ModuleOptions options) :
maxElemPerDepth(maxElemPerDepth), columns(columns), focusedBuffer(focusedBuffer), focusedStack(focusedStack) maxElemPerDepth(maxElemPerDepth), columns(columns), focusedBuffer(focusedBuffer), focusedStack(focusedStack)
{ {
for (unsigned int i = 0; i < maxElemPerDepth.size(); i++) for (unsigned int i = 0; i < maxElemPerDepth.size(); i++)
depthLstm.emplace_back(register_module(fmt::format("lstm_{}",i), LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options))); depthModules.emplace_back(register_module(fmt::format("lstm_{}",i), LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options)));
} }
torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input) torch::Tensor DepthLayerTreeEmbeddingModule::forward(torch::Tensor input)
{ {
auto context = input.narrow(1, firstInputIndex, getInputSize()); auto context = input.narrow(1, firstInputIndex, getInputSize());
...@@ -17,24 +17,24 @@ torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input) ...@@ -17,24 +17,24 @@ torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input)
for (unsigned int focused = 0; focused < focusedBuffer.size()+focusedStack.size(); focused++) for (unsigned int focused = 0; focused < focusedBuffer.size()+focusedStack.size(); focused++)
for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++) for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
{ {
outputs.emplace_back(depthLstm[depth](context.narrow(1, offset, maxElemPerDepth[depth]*columns.size()).view({input.size(0), maxElemPerDepth[depth], (long)columns.size()*input.size(2)}))); outputs.emplace_back(depthModules[depth]->forward(context.narrow(1, offset, maxElemPerDepth[depth]*columns.size()).view({input.size(0), maxElemPerDepth[depth], (long)columns.size()*input.size(2)})));
offset += maxElemPerDepth[depth]*columns.size(); offset += maxElemPerDepth[depth]*columns.size();
} }
return torch::cat(outputs, 1); return torch::cat(outputs, 1);
} }
std::size_t DepthLayerTreeEmbeddingImpl::getOutputSize() std::size_t DepthLayerTreeEmbeddingModule::getOutputSize()
{ {
std::size_t outputSize = 0; std::size_t outputSize = 0;
for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++) for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
outputSize += depthLstm[depth]->getOutputSize(maxElemPerDepth[depth]); outputSize += depthModules[depth]->getOutputSize(maxElemPerDepth[depth]);
return outputSize*(focusedBuffer.size()+focusedStack.size()); return outputSize*(focusedBuffer.size()+focusedStack.size());
} }
std::size_t DepthLayerTreeEmbeddingImpl::getInputSize() std::size_t DepthLayerTreeEmbeddingModule::getInputSize()
{ {
int inputSize = 0; int inputSize = 0;
for (int maxElem : maxElemPerDepth) for (int maxElem : maxElemPerDepth)
...@@ -42,7 +42,7 @@ std::size_t DepthLayerTreeEmbeddingImpl::getInputSize() ...@@ -42,7 +42,7 @@ std::size_t DepthLayerTreeEmbeddingImpl::getInputSize()
return inputSize; return inputSize;
} }
void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const void DepthLayerTreeEmbeddingModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
{ {
std::vector<long> focusedIndexes; std::vector<long> focusedIndexes;
......
#include "FocusedColumnLSTM.hpp" #include "FocusedColumnModule.hpp"
FocusedColumnLSTMImpl::FocusedColumnLSTMImpl(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : focusedBuffer(focusedBuffer), focusedStack(focusedStack), column(column), maxNbElements(maxNbElements) FocusedColumnModule::FocusedColumnModule(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options) : focusedBuffer(focusedBuffer), focusedStack(focusedStack), column(column), maxNbElements(maxNbElements)
{ {
lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options)); myModule = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
} }
torch::Tensor FocusedColumnLSTMImpl::forward(torch::Tensor input) torch::Tensor FocusedColumnModule::forward(torch::Tensor input)
{ {
std::vector<torch::Tensor> outputs; std::vector<torch::Tensor> outputs;
for (unsigned int i = 0; i < focusedBuffer.size()+focusedStack.size(); i++) for (unsigned int i = 0; i < focusedBuffer.size()+focusedStack.size(); i++)
outputs.emplace_back(lstm(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements))); outputs.emplace_back(myModule->forward(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements)));
return torch::cat(outputs, 1); return torch::cat(outputs, 1);
} }
std::size_t FocusedColumnLSTMImpl::getOutputSize() std::size_t FocusedColumnModule::getOutputSize()
{ {
return (focusedBuffer.size()+focusedStack.size())*lstm->getOutputSize(maxNbElements); return (focusedBuffer.size()+focusedStack.size())*myModule->getOutputSize(maxNbElements);
} }
std::size_t FocusedColumnLSTMImpl::getInputSize() std::size_t FocusedColumnModule::getInputSize()
{ {
return (focusedBuffer.size()+focusedStack.size()) * maxNbElements; return (focusedBuffer.size()+focusedStack.size()) * maxNbElements;
} }
void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const void FocusedColumnModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
{ {
std::vector<long> focusedIndexes; std::vector<long> focusedIndexes;
......
#include "GRU.hpp"
// Build the underlying torch GRU from the shared ModuleOptions tuple.
// Tuple layout (by index): <0> batch_first, <1> bidirectional,
// <2> num_layers, <3> dropout, <4> outputAll (flatten every timestep
// in forward() instead of only the sequence end(s)).
GRUImpl::GRUImpl(int inputSize, int outputSize, ModuleOptions options) : outputAll(std::get<4>(options))
{
  auto opts = torch::nn::GRUOptions(inputSize, outputSize);
  opts.batch_first(std::get<0>(options));
  opts.bidirectional(std::get<1>(options));
  opts.num_layers(std::get<2>(options));
  opts.dropout(std::get<3>(options));
  gru = register_module("gru", torch::nn::GRU(opts));
}
// Run the GRU and flatten its output into one vector per batch element.
// outputAll: concatenation of every timestep. Otherwise: the last
// timestep only — or, for a bidirectional GRU, the first and last
// timesteps concatenated (both carry the two directions' states).
torch::Tensor GRUImpl::forward(torch::Tensor input)
{
  auto allSteps = std::get<0>(gru(input));
  if (outputAll)
    return allSteps.reshape({allSteps.size(0), -1});
  auto lastStep = allSteps.narrow(1, allSteps.size(1)-1, 1).squeeze(1);
  if (!gru->options.bidirectional())
    return lastStep;
  auto firstStep = allSteps.narrow(1, 0, 1).squeeze(1);
  return torch::cat({firstStep, lastStep}, 1);
}
// Size of the flattened vector forward() produces for a sequence of
// the given length (must mirror forward()'s branching exactly).
int GRUImpl::getOutputSize(int sequenceLength)
{
  const auto hidden = gru->options.hidden_size();
  const int directions = gru->options.bidirectional() ? 2 : 1;
  if (outputAll)
    return sequenceLength * hidden * directions;
  // Non-outputAll bidirectional mode concatenates both sequence ends,
  // each holding both directions: 2 ends x 2 directions = 4.
  return gru->options.bidirectional() ? 4 * hidden : hidden;
}
#include "GRUNetwork.hpp"
// First-draft stub: the GRU-based network is not wired up yet. The body
// below is the intended module layout (a port of the LSTM network to GRU
// submodules), kept as disabled code until it is ready to be enabled.
// NOTE(review): every constructor parameter is currently unused.
GRUNetworkImpl::GRUNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextGRUSize, int focusedGRUSize, int rawInputGRUSize, int splitTransGRUSize, int numLayers, bool bigru, float gruDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d)
{
// MyModule::ModuleOptions gruOptions{true,bigru,numLayers,gruDropout,false};
// auto gruOptionsAll = gruOptions;
// std::get<4>(gruOptionsAll) = true;
//
// int currentOutputSize = embeddingsSize;
// int currentInputSize = 1;
//
// contextGRU = register_module("contextGRU", ContextModule(columns, embeddingsSize, contextGRUSize, bufferContext, stackContext, gruOptions, unknownValueThreshold));
// contextGRU->setFirstInputIndex(currentInputSize);
// currentOutputSize += contextGRU->getOutputSize();
// currentInputSize += contextGRU->getInputSize();
//
// if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
// {
// rawInputGRU = register_module("rawInputGRU", RawInputModule(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputGRUSize, gruOptionsAll));
// rawInputGRU->setFirstInputIndex(currentInputSize);
// currentOutputSize += rawInputGRU->getOutputSize();
// currentInputSize += rawInputGRU->getInputSize();
// }
//
// if (!treeEmbeddingColumns.empty())
// {
// treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbeddingModule(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,gruOptions));
// treeEmbedding->setFirstInputIndex(currentInputSize);
// currentOutputSize += treeEmbedding->getOutputSize();
// currentInputSize += treeEmbedding->getInputSize();
// }
//
// splitTransGRU = register_module("splitTransGRU", SplitTransModule(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransGRUSize, gruOptionsAll));
// splitTransGRU->setFirstInputIndex(currentInputSize);
// currentOutputSize += splitTransGRU->getOutputSize();
// currentInputSize += splitTransGRU->getInputSize();
//
// for (unsigned int i = 0; i < focusedColumns.size(); i++)
// {
// focusedLstms.emplace_back(register_module(fmt::format("GRU_{}", focusedColumns[i]), FocusedColumnModule(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedGRUSize, gruOptions)));
// focusedLstms.back()->setFirstInputIndex(currentInputSize);
// currentOutputSize += focusedLstms.back()->getOutputSize();
// currentInputSize += focusedLstms.back()->getInputSize();
// }
//
// wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
// if (drop2d)
// embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue));
// else
// embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue));
// inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout));
//
// mlp = register_module("mlp", MLP(currentOutputSize, mlpParams));
//
// for (auto & it : nbOutputsPerState)
// outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
}
// First-draft stub: echoes its input unchanged. The disabled code below
// is the planned pipeline (embed, dropout, per-submodule encoding,
// MLP, then the per-state output layer).
torch::Tensor GRUNetworkImpl::forward(torch::Tensor input)
{
return input;
// if (input.dim() == 1)
// input = input.unsqueeze(0);
//
// auto embeddings = wordEmbeddings(input);
// if (embeddingsDropout2d.is_empty())
// embeddings = embeddingsDropout(embeddings);
// else
// embeddings = embeddingsDropout2d(embeddings);
//
// std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)};
//
// outputs.emplace_back(contextGRU(embeddings));
//
// if (!rawInputGRU.is_empty())
// outputs.emplace_back(rawInputGRU(embeddings));
//
// if (!treeEmbedding.is_empty())
// outputs.emplace_back(treeEmbedding(embeddings));
//
// outputs.emplace_back(splitTransGRU(embeddings));
//
// for (auto & gru : focusedLstms)
// outputs.emplace_back(gru(embeddings));
//
// auto totalInput = inputDropout(torch::cat(outputs, 1));
//
// return outputLayersPerState.at(getState())(mlp(totalInput));
}
// First-draft stub: always returns an empty context. The disabled code
// below is the planned extraction — each submodule appends its feature
// indexes to the context via addToContext().
std::vector<std::vector<long>> GRUNetworkImpl::extractContext(Config & config, Dict & dict) const
{
std::vector<std::vector<long>> context;
return context;
// if (dict.size() >= maxNbEmbeddings)
// util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
//
// context.emplace_back();
//
// context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
//
// contextGRU->addToContext(context, dict, config, mustSplitUnknown());
//
// if (!rawInputGRU.is_empty())
// rawInputGRU->addToContext(context, dict, config, mustSplitUnknown());
//
// if (!treeEmbedding.is_empty())
// treeEmbedding->addToContext(context, dict, config, mustSplitUnknown());
//
// splitTransGRU->addToContext(context, dict, config, mustSplitUnknown());
//
// for (auto & gru : focusedLstms)
// gru->addToContext(context, dict, config, mustSplitUnknown());
//
// if (!mustSplitUnknown() && context.size() > 1)
// util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size()));
//
// return context;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment