diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 4951e5dce72e57b6d02f74598e426c248ce0ddcb..0684d56fcad4bf1e2866103378bc4e9878354667 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -19,6 +19,8 @@ class Classifier
 
   void initNeuralNetwork(const std::vector<std::string> & definition);
   void initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
+  void initGRU(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
+  void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
 
   public :
 
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index d9ead9a52e3dbf7e0beebd6ff08bc515712862a9..113ee9bda8932ccb86e95e8521eb5292dc5d429f 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -1,7 +1,9 @@
 #include "Classifier.hpp"
 #include "util.hpp"
 #include "LSTMNetwork.hpp"
+#include "GRUNetwork.hpp"
 #include "RandomNetwork.hpp"
+#include "ModularNetwork.hpp"
 
 Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition)
 {
@@ -87,6 +89,10 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
     this->nn.reset(new RandomNetworkImpl(nbOutputsPerState));
   else if (networkType == "LSTM")
     initLSTM(definition, curIndex, nbOutputsPerState);
+  else if (networkType == "GRU")
+    initGRU(definition, curIndex, nbOutputsPerState);
+  else if (networkType == "Modular")
+    initModular(definition, curIndex, nbOutputsPerState);
   else
-    util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM'", networkType));
+    util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM, GRU, Modular'", networkType));
 
@@ -349,3 +355,241 @@ void Classifier::setState(const std::string & state)
   nn->setState(state);
 }
 
+void Classifier::initGRU(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
+{
+  int unknownValueThreshold;
+  std::vector<int> bufferContext, stackContext;
+  std::vector<std::string> columns, focusedColumns, treeEmbeddingColumns;
+  std::vector<int> focusedBuffer, focusedStack;
+  std::vector<int> treeEmbeddingBuffer, treeEmbeddingStack;
+  std::vector<int> maxNbElements;
+  std::vector<int> treeEmbeddingNbElems;
+  std::vector<std::pair<int, float>> mlp;
+  int rawInputLeftWindow, rawInputRightWindow;
+  int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, treeEmbeddingSize;
+  bool bilstm, drop2d;
+  float lstmDropout, embeddingsDropout, totalInputDropout;
+
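+  // Each hyperparameter is read from its own definition line, in the fixed
+  // order below; the textual labels (e.g. "Buffer context :") are optional,
+  // since every regex also accepts a bare value.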
definition[curIndex] : "", "(Unknown value threshold :) unknownValueThreshold")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Buffer context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&bufferContext](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + bufferContext.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Buffer context :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Stack context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&stackContext](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + stackContext.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Stack context :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&columns](auto sm) + { + columns = util::split(sm.str(1), ' '); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Columns :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedBuffer](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + focusedBuffer.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused buffer :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedStack](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + focusedStack.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused stack :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedColumns](auto sm) + { + focusedColumns = util::split(sm.str(1), ' '); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused columns :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Max nb elements :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&maxNbElements](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + maxNbElements.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(Max nb elements :) {size1 size2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input left window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLeftWindow](auto sm) + { + rawInputLeftWindow = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input left window :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input right window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputRightWindow](auto sm) + { + rawInputRightWindow = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input right window :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsSize](auto sm) + { + embeddingsSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings size :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:MLP :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&mlp](auto sm) + { + auto params = util::split(sm.str(1), ' '); + if (params.size() % 2) + util::myThrow("MLP must have even number of parameters"); + for (unsigned int i = 0; i < params.size()/2; i++) + mlp.emplace_back(std::make_pair(std::stoi(params[2*i]), std::stof(params[2*i+1]))); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(MLP :) {hidden1 dropout1 hidden2 dropout2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Context LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&contextLSTMSize](auto sm) + { + contextLSTMSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Context LSTM size :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&focusedLSTMSize](auto sm) + { + focusedLSTMSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused LSTM size :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Rawinput LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLSTMSize](auto sm) + { + rawInputLSTMSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(Raw LSTM size :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Split trans LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&splitTransLSTMSize](auto sm) + { + splitTransLSTMSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Split trans LSTM size :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Num layers :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&nbLayers](auto sm) + { + nbLayers = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Num layers :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:BiLSTM :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&bilstm](auto sm) + { + bilstm = sm.str(1) == "true"; + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(BiLSTM :) true|false")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LSTM dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&lstmDropout](auto sm) + { + lstmDropout = std::stof(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(LSTM dropout :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Total input dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&totalInputDropout](auto sm) + { + totalInputDropout = std::stof(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Total input dropout :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsDropout](auto sm) + { + embeddingsDropout = std::stof(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings dropout :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Dropout 2d :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&drop2d](auto sm) + { + drop2d = sm.str(1) == "true"; + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Dropout 2d :) true|false")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingColumns](auto sm) + { + treeEmbeddingColumns = util::split(sm.str(1), ' '); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(Tree embedding columns :) {column1 column2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingBuffer](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + treeEmbeddingBuffer.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding buffer :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingStack](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + treeEmbeddingStack.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding stack :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding nb :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingNbElems](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + treeEmbeddingNbElems.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding nb :) {size1 size2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&treeEmbeddingSize](auto sm) + { + treeEmbeddingSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(Tree embedding size :) value")); + + this->nn.reset(new GRUNetworkImpl(nbOutputsPerState, unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d)); +} + +void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState) +{ + std::string anyBlanks = "(?:(?:\\s|\\t)*)"; + std::regex endRegex(fmt::format("{}End{}",anyBlanks,anyBlanks)); + std::vector<std::string> modulesDefinitions; + + for (; curIndex < definition.size(); curIndex++) + { + if (util::doIfNameMatch(endRegex,definition[curIndex],[](auto sm){})) + { + curIndex++; + break; + } + modulesDefinitions.emplace_back(definition[curIndex]); + } + + this->nn.reset(new ModularNetworkImpl(nbOutputsPerState, modulesDefinitions)); +} + diff --git a/torch_modules/include/ContextLSTM.hpp b/torch_modules/include/ContextModule.hpp similarity index 59% rename from torch_modules/include/ContextLSTM.hpp rename to torch_modules/include/ContextModule.hpp index 3e3bbacac0e56cfd38e981279a0f6a54c1f41b3d..a9b609034023857a245c985990cc37f1d06f2bb7 100644 --- a/torch_modules/include/ContextLSTM.hpp +++ b/torch_modules/include/ContextModule.hpp @@ -1,15 +1,18 @@ -#ifndef CONTEXTLSTM__H -#define CONTEXTLSTM__H +#ifndef CONTEXTMODULE__H +#define CONTEXTMODULE__H #include <torch/torch.h> #include "Submodule.hpp" +#include "MyModule.hpp" +#include "GRU.hpp" #include "LSTM.hpp" -class ContextLSTMImpl : public torch::nn::Module, public Submodule +class ContextModuleImpl : public Submodule { private : - LSTM lstm{nullptr}; + torch::nn::Embedding wordEmbeddings{nullptr}; + std::shared_ptr<MyModule> myModule{nullptr}; std::vector<std::string> columns; std::vector<int> bufferContext; std::vector<int> stackContext; @@ -18,13 +21,13 @@ class ContextLSTMImpl : public torch::nn::Module, public Submodule public : - ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold); + ContextModuleImpl(const std::string & definition); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; }; -TORCH_MODULE(ContextLSTM); +TORCH_MODULE(ContextModule); #endif diff --git a/torch_modules/include/DepthLayerTreeEmbedding.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp similarity index 59% rename from torch_modules/include/DepthLayerTreeEmbedding.hpp rename to torch_modules/include/DepthLayerTreeEmbeddingModule.hpp index 2a8f7e8ca0ccd8fea1313e4b4437700c9bdd6bef..cd6c33df504b5e977ae1cece09f6bf92dec60e7e 100644 --- a/torch_modules/include/DepthLayerTreeEmbedding.hpp +++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp @@ -3,9 +3,11 @@ #include <torch/torch.h> #include "Submodule.hpp" +#include "MyModule.hpp" #include "LSTM.hpp" +#include "GRU.hpp" -class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule +class 
+  for (; curIndex < definition.size(); curIndex++)
+  {
+    if (util::doIfNameMatch(endRegex,definition[curIndex],[](auto sm){}))
+    {
+      curIndex++;
+      break;
+    }
+    modulesDefinitions.emplace_back(definition[curIndex]);
+  }
+
+  this->nn.reset(new ModularNetworkImpl(nbOutputsPerState, modulesDefinitions));
+}
+
diff --git a/torch_modules/include/ContextLSTM.hpp b/torch_modules/include/ContextModule.hpp
similarity index 59%
rename from torch_modules/include/ContextLSTM.hpp
rename to torch_modules/include/ContextModule.hpp
index 3e3bbacac0e56cfd38e981279a0f6a54c1f41b3d..a9b609034023857a245c985990cc37f1d06f2bb7 100644
--- a/torch_modules/include/ContextLSTM.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -1,15 +1,18 @@
-#ifndef CONTEXTLSTM__H
-#define CONTEXTLSTM__H
+#ifndef CONTEXTMODULE__H
+#define CONTEXTMODULE__H
 
 #include <torch/torch.h>
 #include "Submodule.hpp"
+#include "MyModule.hpp"
+#include "GRU.hpp"
 #include "LSTM.hpp"
 
-class ContextLSTMImpl : public torch::nn::Module, public Submodule
+class ContextModuleImpl : public Submodule
 {
   private :
 
-  LSTM lstm{nullptr};
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<std::string> columns;
   std::vector<int> bufferContext;
   std::vector<int> stackContext;
@@ -18,13 +21,13 @@
 
   public :
 
-  ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold);
+  ContextModuleImpl(const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
-TORCH_MODULE(ContextLSTM);
+TORCH_MODULE(ContextModule);
 
 #endif
diff --git a/torch_modules/include/DepthLayerTreeEmbedding.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
similarity index 59%
rename from torch_modules/include/DepthLayerTreeEmbedding.hpp
rename to torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
index 2a8f7e8ca0ccd8fea1313e4b4437700c9bdd6bef..cd6c33df504b5e977ae1cece09f6bf92dec60e7e 100644
--- a/torch_modules/include/DepthLayerTreeEmbedding.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
@@ -3,9 +3,11 @@
 
 #include <torch/torch.h>
 #include "Submodule.hpp"
+#include "MyModule.hpp"
 #include "LSTM.hpp"
+#include "GRU.hpp"
 
-class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule
+class DepthLayerTreeEmbeddingModule : public Submodule
 {
   private :
 
@@ -13,17 +15,16 @@ class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule
   std::vector<std::string> columns;
   std::vector<int> focusedBuffer;
   std::vector<int> focusedStack;
-  std::vector<LSTM> depthLstm;
+  std::vector<std::shared_ptr<MyModule>> depthModules;
 
   public :
 
-  DepthLayerTreeEmbeddingImpl(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, LSTMImpl::LSTMOptions options);
+  DepthLayerTreeEmbeddingModule(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, MyModule::ModuleOptions options);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
-TORCH_MODULE(DepthLayerTreeEmbedding);
 
 #endif
diff --git a/torch_modules/include/FocusedColumnLSTM.hpp b/torch_modules/include/FocusedColumnModule.hpp
similarity index 54%
rename from torch_modules/include/FocusedColumnLSTM.hpp
rename to torch_modules/include/FocusedColumnModule.hpp
index fd5d915df6d42d24294e6a75dd42c87d6e81dec1..9c2732aa82d0b23490a3d36ef527c26aa55a1822 100644
--- a/torch_modules/include/FocusedColumnLSTM.hpp
+++ b/torch_modules/include/FocusedColumnModule.hpp
@@ -1,28 +1,29 @@
-#ifndef FOCUSEDCOLUMNLSTM__H
-#define FOCUSEDCOLUMNLSTM__H
+#ifndef FOCUSEDCOLUMNMODULE__H
+#define FOCUSEDCOLUMNMODULE__H
 
 #include <torch/torch.h>
 #include "Submodule.hpp"
+#include "MyModule.hpp"
 #include "LSTM.hpp"
+#include "GRU.hpp"
 
-class FocusedColumnLSTMImpl : public torch::nn::Module, public Submodule
+class FocusedColumnModule : public Submodule
 {
   private :
 
-  LSTM lstm{nullptr};
+  std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<int> focusedBuffer, focusedStack;
   std::string column;
   int maxNbElements;
 
   public :
 
-  FocusedColumnLSTMImpl(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options);
+  FocusedColumnModule(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
-TORCH_MODULE(FocusedColumnLSTM);
 
 #endif
diff --git a/torch_modules/include/GRU.hpp b/torch_modules/include/GRU.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..7980db43d5a8a808b2af96ffe59c12ec37fa539e
--- /dev/null
+++ b/torch_modules/include/GRU.hpp
@@ -0,0 +1,23 @@
+#ifndef GRU__H
+#define GRU__H
+
+#include <torch/torch.h>
+#include "MyModule.hpp"
+
+class GRUImpl : public MyModule
+{
+  private :
+
+  torch::nn::GRU gru{nullptr};
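+  // When true, forward() returns the flattened outputs of every timestep;
+  // otherwise only the sequence extremities are kept.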
+  bool outputAll;
+
+  public :
+
+  GRUImpl(int inputSize, int outputSize, ModuleOptions options);
+  torch::Tensor forward(torch::Tensor input);
+  int getOutputSize(int sequenceLength);
+};
+TORCH_MODULE(GRU);
+
+#endif
diff --git a/torch_modules/include/GRUNetwork.hpp b/torch_modules/include/GRUNetwork.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..ecff8a05c7604c45b04fbc9f2caf6c0637dc7558
--- /dev/null
+++ b/torch_modules/include/GRUNetwork.hpp
@@ -0,0 +1,36 @@
+#ifndef GRUNETWORK__H
+#define GRUNETWORK__H
+
+#include "NeuralNetwork.hpp"
+#include "ContextModule.hpp"
+#include "RawInputModule.hpp"
+#include "SplitTransModule.hpp"
+#include "FocusedColumnModule.hpp"
+#include "DepthLayerTreeEmbeddingModule.hpp"
+#include "MLP.hpp"
+
+class GRUNetworkImpl : public NeuralNetworkImpl
+{
+//  private :
+//
+//  torch::nn::Embedding wordEmbeddings{nullptr};
+//  torch::nn::Dropout2d embeddingsDropout2d{nullptr};
+//  torch::nn::Dropout embeddingsDropout{nullptr};
+//  torch::nn::Dropout inputDropout{nullptr};
+//
+//  MLP mlp{nullptr};
+//  ContextModule contextGRU{nullptr};
+//  RawInputModule rawInputGRU{nullptr};
+//  SplitTransModule splitTransGRU{nullptr};
+//  DepthLayerTreeEmbeddingModule treeEmbedding{nullptr};
+//  std::vector<FocusedColumnModule> focusedLstms;
+//  std::map<std::string,torch::nn::Linear> outputLayersPerState;
+
+  public :
+
+  GRUNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextGRUSize, int focusedGRUSize, int rawInputGRUSize, int splitTransGRUSize, int numLayers, bool bigru, float gruDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d);
+  torch::Tensor forward(torch::Tensor input) override;
+  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
+};
+
+#endif
diff --git a/torch_modules/include/LSTM.hpp b/torch_modules/include/LSTM.hpp
index c45cb9fff41aa6f9fc0883bcf7e37c9db5e3a8e9..7e4ed0c2e408cbb6cbed16b33082aec48bb69105 100644
--- a/torch_modules/include/LSTM.hpp
+++ b/torch_modules/include/LSTM.hpp
@@ -2,14 +2,10 @@
 #define LSTM__H
 
 #include <torch/torch.h>
-#include "fmt/core.h"
+#include "MyModule.hpp"
 
-class LSTMImpl : public torch::nn::Module
+class LSTMImpl : public MyModule
 {
-  public :
-
-  using LSTMOptions = std::tuple<bool,bool,int,float,bool>;
-
   private :
 
   torch::nn::LSTM lstm{nullptr};
@@ -17,7 +13,7 @@ class LSTMImpl : public torch::nn::Module
 
   public :
 
-  LSTMImpl(int inputSize, int outputSize, LSTMOptions options);
+  LSTMImpl(int inputSize, int outputSize, ModuleOptions options);
   torch::Tensor forward(torch::Tensor input);
   int getOutputSize(int sequenceLength);
 };
diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp
index 76b9303a8cb7245fe27e42d9b7f9673a832ef2e3..d83acf5d0ada988972ca43e749644ad16fc49bae 100644
--- a/torch_modules/include/LSTMNetwork.hpp
+++ b/torch_modules/include/LSTMNetwork.hpp
@@ -2,29 +2,29 @@
 #define LSTMNETWORK__H
 
 #include "NeuralNetwork.hpp"
-#include "ContextLSTM.hpp"
-#include "RawInputLSTM.hpp"
-#include "SplitTransLSTM.hpp"
-#include "FocusedColumnLSTM.hpp"
+#include "ContextModule.hpp"
+#include "RawInputModule.hpp"
+#include "SplitTransModule.hpp"
+#include "FocusedColumnModule.hpp"
+#include "DepthLayerTreeEmbeddingModule.hpp"
 #include "MLP.hpp"
-#include "DepthLayerTreeEmbedding.hpp"
 
 class LSTMNetworkImpl : public NeuralNetworkImpl
 {
-  private :
-
-  torch::nn::Embedding wordEmbeddings{nullptr};
-  torch::nn::Dropout2d embeddingsDropout2d{nullptr};
-  torch::nn::Dropout embeddingsDropout{nullptr};
-  torch::nn::Dropout inputDropout{nullptr};
-
-  MLP mlp{nullptr};
-  ContextLSTM contextLSTM{nullptr};
-  RawInputLSTM rawInputLSTM{nullptr};
-  SplitTransLSTM splitTransLSTM{nullptr};
-  DepthLayerTreeEmbedding treeEmbedding{nullptr};
-  std::vector<FocusedColumnLSTM> focusedLstms;
-  std::map<std::string,torch::nn::Linear> outputLayersPerState;
+//  private :
+//
+//  torch::nn::Embedding wordEmbeddings{nullptr};
+//  torch::nn::Dropout2d embeddingsDropout2d{nullptr};
+//  torch::nn::Dropout embeddingsDropout{nullptr};
+//  torch::nn::Dropout inputDropout{nullptr};
+//
+//  MLP mlp{nullptr};
+//  ContextModule contextLSTM{nullptr};
+//  RawInputModule rawInputLSTM{nullptr};
+//  SplitTransModule splitTransLSTM{nullptr};
+//  DepthLayerTreeEmbeddingModule treeEmbedding{nullptr};
+//  std::vector<FocusedColumnModule> focusedLstms;
+//  std::map<std::string,torch::nn::Linear> outputLayersPerState;
 
   public :
diff --git a/torch_modules/include/MLP.hpp b/torch_modules/include/MLP.hpp
index be272f1cd1369a7b1290aefd1868265111ac00da..bd108a461107e1fa4c0bb609c733f03515b3e785 100644
--- a/torch_modules/include/MLP.hpp
+++ b/torch_modules/include/MLP.hpp
@@ -13,7 +13,7 @@ class MLPImpl : public torch::nn::Module
 
   public :
 
-  MLPImpl(int inputSize, std::vector<std::pair<int, float>> params);
+  MLPImpl(int inputSize, std::string definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t outputSize() const;
 };
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..871c8bb3c1f00652a0b5a6849a412811c19d9094
--- /dev/null
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -0,0 +1,32 @@
+#ifndef MODULARNETWORK__H
+#define MODULARNETWORK__H
+
+#include "NeuralNetwork.hpp"
+#include "ContextModule.hpp"
+#include "RawInputModule.hpp"
+#include "SplitTransModule.hpp"
+#include "FocusedColumnModule.hpp"
+#include "DepthLayerTreeEmbeddingModule.hpp"
+#include "MLP.hpp"
+
+class ModularNetworkImpl : public NeuralNetworkImpl
+{
+  private :
+
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  torch::nn::Dropout2d embeddingsDropout2d{nullptr};
+  torch::nn::Dropout embeddingsDropout{nullptr};
+  torch::nn::Dropout inputDropout{nullptr};
+
+  MLP mlp{nullptr};
+  std::vector<std::shared_ptr<Submodule>> modules;
+  std::map<std::string,torch::nn::Linear> outputLayersPerState;
+
+  public :
+
+  ModularNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions);
+  torch::Tensor forward(torch::Tensor input) override;
+  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
+};
+
+#endif
diff --git a/torch_modules/include/MyModule.hpp b/torch_modules/include/MyModule.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..02018a3473b892121506181c24bdc5642827f3e6
--- /dev/null
+++ b/torch_modules/include/MyModule.hpp
@@ -0,0 +1,25 @@
+#ifndef MYMODULE__H
+#define MYMODULE__H
+
+#include <torch/torch.h>
+
+class MyModule : public torch::nn::Module
+{
+  public :
+
ModuleOptions & num_layers(int num) {std::get<2>(*this)=num; return *this;} + ModuleOptions & dropout(float val) {std::get<3>(*this)=val; return *this;} + ModuleOptions & complete(bool val) {std::get<4>(*this)=val; return *this;} + }; + + public : + + virtual int getOutputSize(int sequenceLength) = 0; + virtual torch::Tensor forward(torch::Tensor) = 0; +}; + +#endif diff --git a/torch_modules/include/RawInputLSTM.hpp b/torch_modules/include/RawInputModule.hpp similarity index 56% rename from torch_modules/include/RawInputLSTM.hpp rename to torch_modules/include/RawInputModule.hpp index 0e08560836b735f181849571ff0beec8f02bc335..0134f974d7640f77cc8334f2828c71863a939230 100644 --- a/torch_modules/include/RawInputLSTM.hpp +++ b/torch_modules/include/RawInputModule.hpp @@ -1,26 +1,27 @@ -#ifndef RAWINPUTLSTM__H -#define RAWINPUTLSTM__H +#ifndef RAWINPUTMODULE__H +#define RAWINPUTMODULE__H #include <torch/torch.h> #include "Submodule.hpp" +#include "MyModule.hpp" #include "LSTM.hpp" +#include "GRU.hpp" -class RawInputLSTMImpl : public torch::nn::Module, public Submodule +class RawInputModule : public Submodule { private : - LSTM lstm{nullptr}; + std::shared_ptr<MyModule> myModule{nullptr}; int leftWindow, rightWindow; public : - RawInputLSTMImpl(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options); + RawInputModule(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; }; -TORCH_MODULE(RawInputLSTM); #endif diff --git a/torch_modules/include/SplitTransLSTM.hpp b/torch_modules/include/SplitTransModule.hpp similarity index 56% rename from torch_modules/include/SplitTransLSTM.hpp rename to torch_modules/include/SplitTransModule.hpp index 85d542ce8510bd0c1d11b2ca6c1f280aeb386d55..8be8e633fab7f182242940eee1308b3cabe0d676 100644 --- a/torch_modules/include/SplitTransLSTM.hpp +++ b/torch_modules/include/SplitTransModule.hpp @@ -1,26 +1,27 @@ -#ifndef SPLITTRANSLSTM__H -#define SPLITTRANSLSTM__H +#ifndef SPLITTRANSMODULE__H +#define SPLITTRANSMODULE__H #include <torch/torch.h> #include "Submodule.hpp" +#include "MyModule.hpp" #include "LSTM.hpp" +#include "GRU.hpp" -class SplitTransLSTMImpl : public torch::nn::Module, public Submodule +class SplitTransModule : public Submodule { private : - LSTM lstm{nullptr}; + std::shared_ptr<MyModule> myModule{nullptr}; int maxNbTrans; public : - SplitTransLSTMImpl(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options); + SplitTransModule(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; }; -TORCH_MODULE(SplitTransLSTM); #endif diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp index cc381013aea518aeefe8422b36537283d5d0da94..77c1a4feb08628615d1d163369f0a9272970d475 100644 --- a/torch_modules/include/Submodule.hpp +++ b/torch_modules/include/Submodule.hpp @@ -1,10 +1,11 @@ #ifndef SUBMODULE__H #define SUBMODULE__H +#include <torch/torch.h> #include "Dict.hpp" 
#include "Config.hpp" -class Submodule +class Submodule : public torch::nn::Module { protected : @@ -16,6 +17,7 @@ class Submodule virtual std::size_t getOutputSize() = 0; virtual std::size_t getInputSize() = 0; virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const = 0; + virtual torch::Tensor forward(torch::Tensor input) = 0; }; #endif diff --git a/torch_modules/src/ContextLSTM.cpp b/torch_modules/src/ContextLSTM.cpp deleted file mode 100644 index d24778878ec51303d55b3b0e7a3bc25f0fbdc9cc..0000000000000000000000000000000000000000 --- a/torch_modules/src/ContextLSTM.cpp +++ /dev/null @@ -1,60 +0,0 @@ -#include "ContextLSTM.hpp" - -ContextLSTMImpl::ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold) : columns(columns), bufferContext(bufferContext), stackContext(stackContext), unknownValueThreshold(unknownValueThreshold) -{ - lstm = register_module("lstm", LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options)); -} - -std::size_t ContextLSTMImpl::getOutputSize() -{ - return lstm->getOutputSize(bufferContext.size()+stackContext.size()); -} - -std::size_t ContextLSTMImpl::getInputSize() -{ - return columns.size()*(bufferContext.size()+stackContext.size()); -} - -void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const -{ - std::vector<long> contextIndexes; - - for (int index : bufferContext) - contextIndexes.emplace_back(config.getRelativeWordIndex(index)); - - for (int index : stackContext) - if (config.hasStack(index)) - contextIndexes.emplace_back(config.getStack(index)); - else - contextIndexes.emplace_back(-1); - - for (auto index : contextIndexes) - for (auto & col : columns) - if (index == -1) - { - for (auto & contextElement : context) - contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); - } - else - { - int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index)); - - for (auto & contextElement : context) - contextElement.push_back(dictIndex); - - for (auto & targetCol : unknownValueColumns) - if (col == targetCol) - if (dict.getNbOccs(dictIndex) <= unknownValueThreshold) - context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr); - } -} - -torch::Tensor ContextLSTMImpl::forward(torch::Tensor input) -{ - auto context = input.narrow(1, firstInputIndex, getInputSize()); - - context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*context.size(2)}); - - return lstm(context); -} - diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d2c1e6af087e2696960a6cdfefc03a9e3827f2b6 --- /dev/null +++ b/torch_modules/src/ContextModule.cpp @@ -0,0 +1,98 @@ +#include "ContextModule.hpp" + +ContextModuleImpl::ContextModuleImpl(const std::string & definition) +{ + std::regex regex("(?:(?:\\s|\\t)*)Unk\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); + if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) + { + try + { + unknownValueThreshold = std::stoi(sm.str(1)); + + for (auto & index : util::split(sm.str(2), ' ')) + 
+  std::regex regex("(?:(?:\\s|\\t)*)Unk\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+    {
+      try
+      {
+        unknownValueThreshold = std::stoi(sm.str(1));
+
+        for (auto & index : util::split(sm.str(2), ' '))
+          bufferContext.emplace_back(std::stoi(index));
+
+        for (auto & index : util::split(sm.str(3), ' '))
+          stackContext.emplace_back(std::stoi(index));
+
+        columns = util::split(sm.str(4), ' ');
+
+        auto subModuleType = sm.str(5);
+        auto subModuleArguments = util::split(sm.str(6), ' ');
+
+        auto options = MyModule::ModuleOptions(true)
+          .bidirectional(std::stoi(subModuleArguments[0]))
+          .num_layers(std::stoi(subModuleArguments[1]))
+          .dropout(std::stof(subModuleArguments[2]))
+          .complete(std::stoi(subModuleArguments[3]));
+
+        int inSize = std::stoi(sm.str(7));
+        int outSize = std::stoi(sm.str(8));
+
+        wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
+
+        if (subModuleType == "LSTM")
+          myModule = register_module("myModule", LSTM(columns.size()*inSize, outSize, options));
+        else if (subModuleType == "GRU")
+          myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
+        else
+          util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
+
+      } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+    }))
+    util::myThrow(fmt::format("invalid definition '{}'", definition));
+}
+
+std::size_t ContextModuleImpl::getOutputSize()
+{
+  return myModule->getOutputSize(bufferContext.size()+stackContext.size());
+}
+
+std::size_t ContextModuleImpl::getInputSize()
+{
+  return columns.size()*(bufferContext.size()+stackContext.size());
+}
+
+void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const
+{
+  std::vector<long> contextIndexes;
+
+  for (int index : bufferContext)
+    contextIndexes.emplace_back(config.getRelativeWordIndex(index));
+
+  for (int index : stackContext)
+    if (config.hasStack(index))
+      contextIndexes.emplace_back(config.getStack(index));
+    else
+      contextIndexes.emplace_back(-1);
+
+  for (auto index : contextIndexes)
+    for (auto & col : columns)
+      if (index == -1)
+      {
+        for (auto & contextElement : context)
+          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+      }
+      else
+      {
+        int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
+
+        for (auto & contextElement : context)
+          contextElement.push_back(dictIndex);
+
+        for (auto & targetCol : unknownValueColumns)
+          if (col == targetCol)
+            if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
+              context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
+      }
+}
+
+torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
+{
+  auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()));
+
+  context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*context.size(2)});
+
+  return myModule->forward(context);
+}
+
diff --git a/torch_modules/src/DepthLayerTreeEmbedding.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
similarity index 68%
rename from torch_modules/src/DepthLayerTreeEmbedding.cpp
rename to torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index b506f9219fd8284094960907975294fbd3a5b28a..f3d7bb2881e5a3d187aa2e95d426168bbabce976 100644
--- a/torch_modules/src/DepthLayerTreeEmbedding.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -1,13 +1,13 @@
-#include "DepthLayerTreeEmbedding.hpp"
+#include "DepthLayerTreeEmbeddingModule.hpp"
 
-DepthLayerTreeEmbeddingImpl::DepthLayerTreeEmbeddingImpl(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, LSTMImpl::LSTMOptions options) :
+DepthLayerTreeEmbeddingModule::DepthLayerTreeEmbeddingModule(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, MyModule::ModuleOptions options) :
   maxElemPerDepth(maxElemPerDepth), columns(columns), focusedBuffer(focusedBuffer), focusedStack(focusedStack)
 {
   for (unsigned int i = 0; i < maxElemPerDepth.size(); i++)
-    depthLstm.emplace_back(register_module(fmt::format("lstm_{}",i), LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options)));
+    depthModules.emplace_back(register_module(fmt::format("lstm_{}",i), LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options)));
 }
 
-torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input)
+torch::Tensor DepthLayerTreeEmbeddingModule::forward(torch::Tensor input)
 {
   auto context = input.narrow(1, firstInputIndex, getInputSize());
 
@@ -17,24 +17,24 @@ torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input)
   for (unsigned int focused = 0; focused < focusedBuffer.size()+focusedStack.size(); focused++)
     for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
     {
-      outputs.emplace_back(depthLstm[depth](context.narrow(1, offset, maxElemPerDepth[depth]*columns.size()).view({input.size(0), maxElemPerDepth[depth], (long)columns.size()*input.size(2)})));
+      outputs.emplace_back(depthModules[depth]->forward(context.narrow(1, offset, maxElemPerDepth[depth]*columns.size()).view({input.size(0), maxElemPerDepth[depth], (long)columns.size()*input.size(2)})));
       offset += maxElemPerDepth[depth]*columns.size();
     }
 
   return torch::cat(outputs, 1);
 }
 
-std::size_t DepthLayerTreeEmbeddingImpl::getOutputSize()
+std::size_t DepthLayerTreeEmbeddingModule::getOutputSize()
 {
   std::size_t outputSize = 0;
 
   for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
-    outputSize += depthLstm[depth]->getOutputSize(maxElemPerDepth[depth]);
+    outputSize += depthModules[depth]->getOutputSize(maxElemPerDepth[depth]);
 
   return outputSize*(focusedBuffer.size()+focusedStack.size());
 }
 
-std::size_t DepthLayerTreeEmbeddingImpl::getInputSize()
+std::size_t DepthLayerTreeEmbeddingModule::getInputSize()
 {
   int inputSize = 0;
   for (int maxElem : maxElemPerDepth)
@@ -42,7 +42,7 @@ std::size_t DepthLayerTreeEmbeddingImpl::getInputSize()
   return inputSize;
 }
 
-void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void DepthLayerTreeEmbeddingModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   std::vector<long> focusedIndexes;
 
diff --git a/torch_modules/src/FocusedColumnLSTM.cpp b/torch_modules/src/FocusedColumnModule.cpp
similarity index 72%
rename from torch_modules/src/FocusedColumnLSTM.cpp
rename to torch_modules/src/FocusedColumnModule.cpp
index e39af636c817fdc1677cfd9131b85ec7fb1bd3ba..717bc747c85be8521271d80499698af525741585 100644
--- a/torch_modules/src/FocusedColumnLSTM.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -1,30 +1,30 @@
-#include "FocusedColumnLSTM.hpp"
+#include "FocusedColumnModule.hpp"
 
-FocusedColumnLSTMImpl::FocusedColumnLSTMImpl(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : focusedBuffer(focusedBuffer), focusedStack(focusedStack), column(column), maxNbElements(maxNbElements)
+FocusedColumnModule::FocusedColumnModule(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options) : focusedBuffer(focusedBuffer), focusedStack(focusedStack), column(column), maxNbElements(maxNbElements)
 {
-  lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
+  myModule = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
 }
 
-torch::Tensor FocusedColumnLSTMImpl::forward(torch::Tensor input)
+torch::Tensor FocusedColumnModule::forward(torch::Tensor input)
 {
   std::vector<torch::Tensor> outputs;
 
   for (unsigned int i = 0; i < focusedBuffer.size()+focusedStack.size(); i++)
-    outputs.emplace_back(lstm(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements)));
+    outputs.emplace_back(myModule->forward(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements)));
 
   return torch::cat(outputs, 1);
 }
 
-std::size_t FocusedColumnLSTMImpl::getOutputSize()
+std::size_t FocusedColumnModule::getOutputSize()
 {
-  return (focusedBuffer.size()+focusedStack.size())*lstm->getOutputSize(maxNbElements);
+  return (focusedBuffer.size()+focusedStack.size())*myModule->getOutputSize(maxNbElements);
 }
 
-std::size_t FocusedColumnLSTMImpl::getInputSize()
+std::size_t FocusedColumnModule::getInputSize()
 {
   return (focusedBuffer.size()+focusedStack.size()) * maxNbElements;
 }
 
-void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void FocusedColumnModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   std::vector<long> focusedIndexes;
 
diff --git a/torch_modules/src/GRU.cpp b/torch_modules/src/GRU.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..fa6de5ccf0d594166f2544465dc5903a204692d4
--- /dev/null
+++ b/torch_modules/src/GRU.cpp
@@ -0,0 +1,34 @@
+#include "GRU.hpp"
+
+GRUImpl::GRUImpl(int inputSize, int outputSize, ModuleOptions options) : outputAll(std::get<4>(options))
+{
+  auto gruOptions = torch::nn::GRUOptions(inputSize, outputSize)
+    .batch_first(std::get<0>(options))
+    .bidirectional(std::get<1>(options))
+    .num_layers(std::get<2>(options))
+    .dropout(std::get<3>(options));
+
+  gru = register_module("gru", torch::nn::GRU(gruOptions));
+}
+
+torch::Tensor GRUImpl::forward(torch::Tensor input)
+{
+  auto gruOut = std::get<0>(gru(input));
+
+  if (outputAll)
+    return gruOut.reshape({gruOut.size(0), -1});
+
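+  // Only the sequence extremities are kept: for a bidirectional GRU the first
+  // and last timesteps (each of size 2*hidden) are concatenated, hence the
+  // factor 4 in getOutputSize().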
+  if (gru->options.bidirectional())
+    return torch::cat({gruOut.narrow(1,0,1).squeeze(1), gruOut.narrow(1,gruOut.size(1)-1,1).squeeze(1)}, 1);
+
+  return gruOut.narrow(1,gruOut.size(1)-1,1).squeeze(1);
+}
+
+int GRUImpl::getOutputSize(int sequenceLength)
+{
+  if (outputAll)
+    return sequenceLength * gru->options.hidden_size() * (gru->options.bidirectional() ? 2 : 1);
+
+  return gru->options.hidden_size() * (gru->options.bidirectional() ? 4 : 1);
+}
+
diff --git a/torch_modules/src/GRUNetwork.cpp b/torch_modules/src/GRUNetwork.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..dc7f36609d5a83f94486ee30dc3025bc30ca5bcc
--- /dev/null
+++ b/torch_modules/src/GRUNetwork.cpp
@@ -0,0 +1,120 @@
+#include "GRUNetwork.hpp"
+
+GRUNetworkImpl::GRUNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextGRUSize, int focusedGRUSize, int rawInputGRUSize, int splitTransGRUSize, int numLayers, bool bigru, float gruDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d)
+{
+//  MyModule::ModuleOptions gruOptions{true,bigru,numLayers,gruDropout,false};
+//  auto gruOptionsAll = gruOptions;
+//  std::get<4>(gruOptionsAll) = true;
+//
+//  int currentOutputSize = embeddingsSize;
+//  int currentInputSize = 1;
+//
+//  contextGRU = register_module("contextGRU", ContextModule(columns, embeddingsSize, contextGRUSize, bufferContext, stackContext, gruOptions, unknownValueThreshold));
+//  contextGRU->setFirstInputIndex(currentInputSize);
+//  currentOutputSize += contextGRU->getOutputSize();
+//  currentInputSize += contextGRU->getInputSize();
+//
+//  if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
+//  {
+//    rawInputGRU = register_module("rawInputGRU", RawInputModule(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputGRUSize, gruOptionsAll));
+//    rawInputGRU->setFirstInputIndex(currentInputSize);
+//    currentOutputSize += rawInputGRU->getOutputSize();
+//    currentInputSize += rawInputGRU->getInputSize();
+//  }
+//
+//  if (!treeEmbeddingColumns.empty())
+//  {
+//    treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbeddingModule(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,gruOptions));
+//    treeEmbedding->setFirstInputIndex(currentInputSize);
+//    currentOutputSize += treeEmbedding->getOutputSize();
+//    currentInputSize += treeEmbedding->getInputSize();
+//  }
+//
+//  splitTransGRU = register_module("splitTransGRU", SplitTransModule(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransGRUSize, gruOptionsAll));
+//  splitTransGRU->setFirstInputIndex(currentInputSize);
+//  currentOutputSize += splitTransGRU->getOutputSize();
+//  currentInputSize += splitTransGRU->getInputSize();
+//
+//  for (unsigned int i = 0; i < focusedColumns.size(); i++)
+//  {
+//    focusedLstms.emplace_back(register_module(fmt::format("GRU_{}", focusedColumns[i]), FocusedColumnModule(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedGRUSize, gruOptions)));
+//    focusedLstms.back()->setFirstInputIndex(currentInputSize);
+//    currentOutputSize += focusedLstms.back()->getOutputSize();
+//    currentInputSize += focusedLstms.back()->getInputSize();
+//  }
+//
+//  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
+//  if (drop2d)
+//    embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue));
+//  else
+//    embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue));
+//  inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout));
+//
+//  mlp = register_module("mlp", MLP(currentOutputSize, mlpParams));
+//
+//  for (auto & it : nbOutputsPerState)
+//    outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
+}
+
+torch::Tensor GRUNetworkImpl::forward(torch::Tensor input)
+{
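+  // Placeholder passthrough; the intended forward pass is kept commented out below.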
+  return input;
+//  if (input.dim() == 1)
+//    input = input.unsqueeze(0);
+//
+//  auto embeddings = wordEmbeddings(input);
+//  if (embeddingsDropout2d.is_empty())
+//    embeddings = embeddingsDropout(embeddings);
+//  else
+//    embeddings = embeddingsDropout2d(embeddings);
+//
+//  std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)};
+//
+//  outputs.emplace_back(contextGRU(embeddings));
+//
+//  if (!rawInputGRU.is_empty())
+//    outputs.emplace_back(rawInputGRU(embeddings));
+//
+//  if (!treeEmbedding.is_empty())
+//    outputs.emplace_back(treeEmbedding(embeddings));
+//
+//  outputs.emplace_back(splitTransGRU(embeddings));
+//
+//  for (auto & gru : focusedLstms)
+//    outputs.emplace_back(gru(embeddings));
+//
+//  auto totalInput = inputDropout(torch::cat(outputs, 1));
+//
+//  return outputLayersPerState.at(getState())(mlp(totalInput));
+}
+
+std::vector<std::vector<long>> GRUNetworkImpl::extractContext(Config & config, Dict & dict) const
+{
+  std::vector<std::vector<long>> context;
+  return context;
+//  if (dict.size() >= maxNbEmbeddings)
+//    util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
+//
+//  context.emplace_back();
+//
+//  context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
+//
+//  contextGRU->addToContext(context, dict, config, mustSplitUnknown());
+//
+//  if (!rawInputGRU.is_empty())
+//    rawInputGRU->addToContext(context, dict, config, mustSplitUnknown());
+//
+//  if (!treeEmbedding.is_empty())
+//    treeEmbedding->addToContext(context, dict, config, mustSplitUnknown());
+//
+//  splitTransGRU->addToContext(context, dict, config, mustSplitUnknown());
+//
+//  for (auto & gru : focusedLstms)
+//    gru->addToContext(context, dict, config, mustSplitUnknown());
+//
+//  if (!mustSplitUnknown() && context.size() > 1)
+//    util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size()));
+//
+//  return context;
+}
+
diff --git a/torch_modules/src/LSTM.cpp b/torch_modules/src/LSTM.cpp
index af89a3dedddc3750451f75442213eeb52482dfda..2844b17a256bac5de90017343fc4c7b2ad466e89 100644
--- a/torch_modules/src/LSTM.cpp
+++ b/torch_modules/src/LSTM.cpp
@@ -1,6 +1,6 @@
 #include "LSTM.hpp"
 
-LSTMImpl::LSTMImpl(int inputSize, int outputSize, LSTMOptions options) : outputAll(std::get<4>(options))
+LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options) : outputAll(std::get<4>(options))
 {
   auto lstmOptions = torch::nn::LSTMOptions(inputSize, outputSize)
     .batch_first(std::get<0>(options))
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
index 476e24dc938bd302db62194fae0e4cb665f82dab..e46d175aec9d5e83f1d365d40cdd0b7e5c3dc977 100644
--- a/torch_modules/src/LSTMNetwork.cpp
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -2,117 +2,119 @@
 
 LSTMNetworkImpl::LSTMNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d)
 {
-  LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false};
-  auto lstmOptionsAll = lstmOptions;
-  std::get<4>(lstmOptionsAll) = true;
-
-  int currentOutputSize = embeddingsSize;
-  int currentInputSize = 1;
-
-  contextLSTM = register_module("contextLSTM", ContextLSTM(columns, embeddingsSize, contextLSTMSize, bufferContext, stackContext, lstmOptions, unknownValueThreshold));
-  contextLSTM->setFirstInputIndex(currentInputSize);
-  currentOutputSize += contextLSTM->getOutputSize();
-  currentInputSize += contextLSTM->getInputSize();
-
-  if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
-  {
-    rawInputLSTM = register_module("rawInputLSTM", RawInputLSTM(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, lstmOptionsAll));
-    rawInputLSTM->setFirstInputIndex(currentInputSize);
-    currentOutputSize += rawInputLSTM->getOutputSize();
-    currentInputSize += rawInputLSTM->getInputSize();
-  }
-
-  if (!treeEmbeddingColumns.empty())
-  {
-    treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,lstmOptions));
-    treeEmbedding->setFirstInputIndex(currentInputSize);
-    currentOutputSize += treeEmbedding->getOutputSize();
-    currentInputSize += treeEmbedding->getInputSize();
-  }
-
-  splitTransLSTM = register_module("splitTransLSTM", SplitTransLSTM(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransLSTMSize, lstmOptionsAll));
-  splitTransLSTM->setFirstInputIndex(currentInputSize);
-  currentOutputSize += splitTransLSTM->getOutputSize();
-  currentInputSize += splitTransLSTM->getInputSize();
-
-  for (unsigned int i = 0; i < focusedColumns.size(); i++)
-  {
-    focusedLstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), FocusedColumnLSTM(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedLSTMSize, lstmOptions)));
-    focusedLstms.back()->setFirstInputIndex(currentInputSize);
-    currentOutputSize += focusedLstms.back()->getOutputSize();
-    currentInputSize += focusedLstms.back()->getInputSize();
-  }
-
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
-  if (drop2d)
-    embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue));
-  else
-    embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue));
-  inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout));
-
-  mlp = register_module("mlp", MLP(currentOutputSize, mlpParams));
-
-  for (auto & it : nbOutputsPerState)
outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second))); +// MyModule::ModuleOptions moduleOptions{true,bilstm,numLayers,lstmDropout,false}; +// auto moduleOptionsAll = moduleOptions; +// std::get<4>(moduleOptionsAll) = true; +// +// int currentOutputSize = embeddingsSize; +// int currentInputSize = 1; +// +// contextLSTM = register_module("contextLSTM", ContextModule(columns, embeddingsSize, contextLSTMSize, bufferContext, stackContext, moduleOptions, unknownValueThreshold)); +// contextLSTM->setFirstInputIndex(currentInputSize); +// currentOutputSize += contextLSTM->getOutputSize(); +// currentInputSize += contextLSTM->getInputSize(); +// +// if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0) +// { +// rawInputLSTM = register_module("rawInputLSTM", RawInputModule(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, moduleOptionsAll)); +// rawInputLSTM->setFirstInputIndex(currentInputSize); +// currentOutputSize += rawInputLSTM->getOutputSize(); +// currentInputSize += rawInputLSTM->getInputSize(); +// } +// +// if (!treeEmbeddingColumns.empty()) +// { +// treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbeddingModule(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,moduleOptions)); +// treeEmbedding->setFirstInputIndex(currentInputSize); +// currentOutputSize += treeEmbedding->getOutputSize(); +// currentInputSize += treeEmbedding->getInputSize(); +// } +// +// splitTransLSTM = register_module("splitTransLSTM", SplitTransModule(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransLSTMSize, moduleOptionsAll)); +// splitTransLSTM->setFirstInputIndex(currentInputSize); +// currentOutputSize += splitTransLSTM->getOutputSize(); +// currentInputSize += splitTransLSTM->getInputSize(); +// +// for (unsigned int i = 0; i < focusedColumns.size(); i++) +// { +// focusedLstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), FocusedColumnModule(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedLSTMSize, moduleOptions))); +// focusedLstms.back()->setFirstInputIndex(currentInputSize); +// currentOutputSize += focusedLstms.back()->getOutputSize(); +// currentInputSize += focusedLstms.back()->getInputSize(); +// } +// +// wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); +// if (drop2d) +// embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue)); +// else +// embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue)); +// inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout)); +// +// mlp = register_module("mlp", MLP(currentOutputSize, mlpParams)); +// +// for (auto & it : nbOutputsPerState) +// outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second))); } torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input) { - if (input.dim() == 1) - input = input.unsqueeze(0); - - auto embeddings = wordEmbeddings(input); - if (embeddingsDropout2d.is_empty()) - embeddings = embeddingsDropout(embeddings); - else - embeddings = embeddingsDropout2d(embeddings); - - std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)}; - - 
outputs.emplace_back(contextLSTM(embeddings)); - - if (!rawInputLSTM.is_empty()) - outputs.emplace_back(rawInputLSTM(embeddings)); - - if (!treeEmbedding.is_empty()) - outputs.emplace_back(treeEmbedding(embeddings)); - - outputs.emplace_back(splitTransLSTM(embeddings)); - - for (auto & lstm : focusedLstms) - outputs.emplace_back(lstm(embeddings)); - - auto totalInput = inputDropout(torch::cat(outputs, 1)); - - return outputLayersPerState.at(getState())(mlp(totalInput)); + return input; +// if (input.dim() == 1) +// input = input.unsqueeze(0); +// +// auto embeddings = wordEmbeddings(input); +// if (embeddingsDropout2d.is_empty()) +// embeddings = embeddingsDropout(embeddings); +// else +// embeddings = embeddingsDropout2d(embeddings); +// +// std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)}; +// +// outputs.emplace_back(contextLSTM(embeddings)); +// +// if (!rawInputLSTM.is_empty()) +// outputs.emplace_back(rawInputLSTM(embeddings)); +// +// if (!treeEmbedding.is_empty()) +// outputs.emplace_back(treeEmbedding(embeddings)); +// +// outputs.emplace_back(splitTransLSTM(embeddings)); +// +// for (auto & lstm : focusedLstms) +// outputs.emplace_back(lstm(embeddings)); +// +// auto totalInput = inputDropout(torch::cat(outputs, 1)); +// +// return outputLayersPerState.at(getState())(mlp(totalInput)); } std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const { - if (dict.size() >= maxNbEmbeddings) - util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings)); - std::vector<std::vector<long>> context; - context.emplace_back(); - - context.back().emplace_back(dict.getIndexOrInsert(config.getState())); - - contextLSTM->addToContext(context, dict, config, mustSplitUnknown()); - - if (!rawInputLSTM.is_empty()) - rawInputLSTM->addToContext(context, dict, config, mustSplitUnknown()); - - if (!treeEmbedding.is_empty()) - treeEmbedding->addToContext(context, dict, config, mustSplitUnknown()); - - splitTransLSTM->addToContext(context, dict, config, mustSplitUnknown()); - - for (auto & lstm : focusedLstms) - lstm->addToContext(context, dict, config, mustSplitUnknown()); - - if (!mustSplitUnknown() && context.size() > 1) - util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size())); - return context; +// if (dict.size() >= maxNbEmbeddings) +// util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings)); +// +// context.emplace_back(); +// +// context.back().emplace_back(dict.getIndexOrInsert(config.getState())); +// +// contextLSTM->addToContext(context, dict, config, mustSplitUnknown()); +// +// if (!rawInputLSTM.is_empty()) +// rawInputLSTM->addToContext(context, dict, config, mustSplitUnknown()); +// +// if (!treeEmbedding.is_empty()) +// treeEmbedding->addToContext(context, dict, config, mustSplitUnknown()); +// +// splitTransLSTM->addToContext(context, dict, config, mustSplitUnknown()); +// +// for (auto & lstm : focusedLstms) +// lstm->addToContext(context, dict, config, mustSplitUnknown()); +// +// if (!mustSplitUnknown() && context.size() > 1) +// util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size())); +// +// return context; } diff --git a/torch_modules/src/MLP.cpp b/torch_modules/src/MLP.cpp index 03880ecca0554d1ab4ad1b0c92fca4f1a0ea2481..ede2ff63f46ee50a640a8e041389974caa5fff64 100644 --- a/torch_modules/src/MLP.cpp +++ 
b/torch_modules/src/MLP.cpp @@ -1,8 +1,25 @@ #include "MLP.hpp" +#include "util.hpp" #include "fmt/core.h" +#include <regex> -MLPImpl::MLPImpl(int inputSize, std::vector<std::pair<int, float>> params) +MLPImpl::MLPImpl(int inputSize, std::string definition) { + std::regex regex("(?:(?:\\s|\\t)*)\\{(.*)\\}(?:(?:\\s|\\t)*)"); + std::vector<std::pair<int, float>> params; + if (!util::doIfNameMatch(regex, definition, [this,&definition,¶ms](auto sm) + { + try + { + auto splited = util::split(sm.str(1), ' '); + for (unsigned int i = 0; i < splited.size()/2; i++) + { + params.emplace_back(std::stoi(splited[2*i]), std::stof(splited[2*i+1])); + } + } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} + })) + util::myThrow(fmt::format("invalid definition '{}'", definition)); + int inSize = inputSize; for (auto & param : params) diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp new file mode 100644 index 0000000000000000000000000000000000000000..21edb84ffb337b5f0bf544e5bf71580153888335 --- /dev/null +++ b/torch_modules/src/ModularNetwork.cpp @@ -0,0 +1,77 @@ +#include "ModularNetwork.hpp" + +ModularNetworkImpl::ModularNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions) +{ + std::string anyBlanks = "(?:(?:\\s|\\t)*)"; + auto splitLine = [anyBlanks](std::string line) + { + std::pair<std::string,std::string> result; + util::doIfNameMatch(std::regex(fmt::format("{}(\\S+){}:{}(.+)",anyBlanks,anyBlanks,anyBlanks)),line,[&result](auto sm) + { + result.first = sm.str(1); + result.second = sm.str(2); + }); + return result; + }; + + int currentInputSize = 0; + int currentOutputSize = 0; + std::string mlpDef; + for (auto & line : definitions) + { + auto splited = splitLine(line); + std::string name = fmt::format("{}_{}", modules.size(), splited.first); + if (splited.first == "Context") + modules.emplace_back(register_module(name, ContextModule(splited.second))); + else if (splited.first == "MLP") + { + mlpDef = splited.second; + continue; + } + else if (splited.first == "InputDropout") + { + inputDropout = register_module("inputDropout", torch::nn::Dropout(std::stof(splited.second))); + continue; + } + else + util::myThrow(fmt::format("unknown module '{}' for line '{}'", splited.first, line)); + + modules.back()->setFirstInputIndex(currentInputSize); + currentInputSize += modules.back()->getInputSize(); + currentOutputSize += modules.back()->getOutputSize(); + } + + if (mlpDef.empty()) + util::myThrow("no MLP definition found"); + if (inputDropout.is_empty()) + util::myThrow("no InputDropout definition found"); + + mlp = register_module("mlp", MLP(currentOutputSize, mlpDef)); + + for (auto & it : nbOutputsPerState) + outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second))); +} + +torch::Tensor ModularNetworkImpl::forward(torch::Tensor input) +{ + if (input.dim() == 1) + input = input.unsqueeze(0); + + std::vector<torch::Tensor> outputs; + + for (auto & mod : modules) + outputs.emplace_back(mod->forward(input)); + + auto totalInput = inputDropout(torch::cat(outputs, 1)); + + return outputLayersPerState.at(getState())(mlp(totalInput)); +} + +std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config, Dict & dict) const +{ + std::vector<std::vector<long>> context(1); + for (auto & mod : modules) + mod->addToContext(context, dict, config, mustSplitUnknown()); + return context; +} + 
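A note on the new 'Modular' network above: each definition line is a "Name : arguments" pair, and the constructor dispatches on the name ('Context', 'MLP', 'InputDropout') while each module parses its own argument string. The following standalone sketch mirrors the splitLine/dispatch step using plain std::regex in place of util::doIfNameMatch; the sample definition lines are invented for illustration, since the argument syntax accepted by ContextModule is not shown in this patch.

#include <iostream>
#include <regex>
#include <string>
#include <utility>
#include <vector>

// "Name : arguments" -> (Name, arguments); std::regex_match stands in
// for the project's util::doIfNameMatch helper.
std::pair<std::string, std::string> splitLine(const std::string & line)
{
  std::smatch sm;
  if (std::regex_match(line, sm, std::regex("\\s*(\\S+)\\s*:\\s*(.+)")))
    return {sm.str(1), sm.str(2)};
  return {};
}

int main()
{
  // Hypothetical 'Modular' classifier block; only 'Context', 'MLP' and
  // 'InputDropout' are recognized by the constructor above.
  std::vector<std::string> definitions{
    "Context : {FORM UPOS} {-3 -2 -1 0 1 2}",
    "MLP : {1600 0.3 800 0.3}",
    "InputDropout : 0.5",
  };

  for (auto & line : definitions)
  {
    auto [name, args] = splitLine(line);
    std::cout << name << " -> " << args << "\n";
  }
  return 0;
}

The constructor additionally threads currentInputSize/currentOutputSize through the recognized modules, so each module knows which slice of the input tensor belongs to it and how much it contributes to the MLP's input width.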
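Likewise, MLPImpl now receives its layer specification as a brace-delimited string of alternating hidden-layer sizes and dropout rates. A minimal sketch of that parse, again substituting std::regex and a string stream for the util helpers; the value "{1600 0.3 800 0.3}" is an invented example:

#include <iostream>
#include <regex>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

int main()
{
  // Hypothetical MLP definition: pairs of (hidden layer size, dropout rate).
  std::string definition = "{1600 0.3 800 0.3}";

  std::vector<std::pair<int, float>> params;
  std::smatch sm;
  if (std::regex_match(definition, sm, std::regex("\\s*\\{(.*)\\}\\s*")))
  {
    std::istringstream iss(sm.str(1));
    int size;
    float dropout;
    while (iss >> size >> dropout) // consume the list two tokens at a time
      params.emplace_back(size, dropout);
  }

  // MLPImpl would build one Linear+Dropout pair per entry.
  for (auto & p : params)
    std::cout << "Linear(out=" << p.first << ") Dropout(" << p.second << ")\n";
  return 0;
}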
diff --git a/torch_modules/src/RawInputLSTM.cpp b/torch_modules/src/RawInputModule.cpp
similarity index 53%
rename from torch_modules/src/RawInputLSTM.cpp
rename to torch_modules/src/RawInputModule.cpp
index c6da426a7807b90bfd52eaf06abe7599c4c517c3..e14451dd206b2dea8a3d50f91a603fca97b74e26 100644
--- a/torch_modules/src/RawInputLSTM.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -1,26 +1,26 @@
-#include "RawInputLSTM.hpp"
+#include "RawInputModule.hpp"
 
-RawInputLSTMImpl::RawInputLSTMImpl(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : leftWindow(leftWindow), rightWindow(rightWindow)
+RawInputModule::RawInputModule(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options) : leftWindow(leftWindow), rightWindow(rightWindow)
 {
-  lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
+  myModule = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
 }
 
-torch::Tensor RawInputLSTMImpl::forward(torch::Tensor input)
+torch::Tensor RawInputModule::forward(torch::Tensor input)
 {
-  return lstm(input.narrow(1, firstInputIndex, getInputSize()));
+  return myModule->forward(input.narrow(1, firstInputIndex, getInputSize()));
 }
 
-std::size_t RawInputLSTMImpl::getOutputSize()
+std::size_t RawInputModule::getOutputSize()
 {
-  return lstm->getOutputSize(leftWindow + rightWindow + 1);
+  return myModule->getOutputSize(leftWindow + rightWindow + 1);
 }
 
-std::size_t RawInputLSTMImpl::getInputSize()
+std::size_t RawInputModule::getInputSize()
 {
   return leftWindow + rightWindow + 1;
 }
 
-void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void RawInputModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   if (leftWindow < 0 or rightWindow < 0)
     return;
diff --git a/torch_modules/src/SplitTransLSTM.cpp b/torch_modules/src/SplitTransLSTM.cpp
deleted file mode 100644
index 99a1b35650e0b60c8c34c22f0a863d1ab1f8c990..0000000000000000000000000000000000000000
--- a/torch_modules/src/SplitTransLSTM.cpp
+++ /dev/null
@@ -1,34 +0,0 @@
-#include "SplitTransLSTM.hpp"
-#include "Transition.hpp"
-
-SplitTransLSTMImpl::SplitTransLSTMImpl(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : maxNbTrans(maxNbTrans)
-{
-  lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
-}
-
-torch::Tensor SplitTransLSTMImpl::forward(torch::Tensor input)
-{
-  return lstm(input.narrow(1, firstInputIndex, getInputSize()));
-}
-
-std::size_t SplitTransLSTMImpl::getOutputSize()
-{
-  return lstm->getOutputSize(maxNbTrans);
-}
-
-std::size_t SplitTransLSTMImpl::getInputSize()
-{
-  return maxNbTrans;
-}
-
-void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
-{
-  auto & splitTransitions = config.getAppliableSplitTransitions();
-  for (auto & contextElement : context)
-    for (int i = 0; i < maxNbTrans; i++)
-      if (i < (int)splitTransitions.size())
-        contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
-      else
-        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
-}
-
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..45a48df4488e884af925d6b50bfca69bdba73e60
--- /dev/null
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -0,0 +1,34 @@
+#include "SplitTransModule.hpp"
+#include "Transition.hpp"
+
+SplitTransModule::SplitTransModule(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options) : maxNbTrans(maxNbTrans)
+{
+  myModule = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
+}
+
+torch::Tensor SplitTransModule::forward(torch::Tensor input)
+{
+  return myModule->forward(input.narrow(1, firstInputIndex, getInputSize()));
+}
+
+std::size_t SplitTransModule::getOutputSize()
+{
+  return myModule->getOutputSize(maxNbTrans);
+}
+
+std::size_t SplitTransModule::getInputSize()
+{
+  return maxNbTrans;
+}
+
+void SplitTransModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+{
+  auto & splitTransitions = config.getAppliableSplitTransitions();
+  for (auto & contextElement : context)
+    for (int i = 0; i < maxNbTrans; i++)
+      if (i < (int)splitTransitions.size())
+        contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
+      else
+        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+}
+
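For reference, SplitTransModule::addToContext gives every context row a fixed width: one dictionary index per currently appliable split transition, padded with Dict::nullValueStr up to maxNbTrans, so the module always slices the same number of input columns. A dictionary-free sketch of that padding rule, with invented transition names:

#include <iostream>
#include <string>
#include <vector>

int main()
{
  const int maxNbTrans = 4;                     // fixed input width of the module
  const std::string nullValueStr = "__null__";  // stand-in for Dict::nullValueStr

  // Invented names; in the parser these come from
  // config.getAppliableSplitTransitions().
  std::vector<std::string> splitTransitions{"SPLIT_2", "SPLIT_3"};

  std::vector<std::string> contextElement;
  for (int i = 0; i < maxNbTrans; i++)
    if (i < (int)splitTransitions.size())
      contextElement.emplace_back(splitTransitions[i]);
    else
      contextElement.emplace_back(nullValueStr);

  for (auto & s : contextElement)
    std::cout << s << ' ';
  std::cout << '\n'; // prints: SPLIT_2 SPLIT_3 __null__ __null__
  return 0;
}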