diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index 0684d56fcad4bf1e2866103378bc4e9878354667..29ff70469d333299dbbc553d44861976c0e800db 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -18,8 +18,6 @@ class Classifier private : void initNeuralNetwork(const std::vector<std::string> & definition); - void initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState); - void initGRU(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState); void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState); public : diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 113ee9bda8932ccb86e95e8521eb5292dc5d429f..64e41937ae2943f01fdaf7a1200671dbeb360f8f 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -1,7 +1,5 @@ #include "Classifier.hpp" #include "util.hpp" -#include "LSTMNetwork.hpp" -#include "GRUNetwork.hpp" #include "RandomNetwork.hpp" #include "ModularNetwork.hpp" @@ -87,10 +85,6 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) if (networkType == "Random") this->nn.reset(new RandomNetworkImpl(nbOutputsPerState)); - else if (networkType == "LSTM") - initLSTM(definition, curIndex, nbOutputsPerState); - else if (networkType == "GRU") - initGRU(definition, curIndex, nbOutputsPerState); else if (networkType == "Modular") initModular(definition, curIndex, nbOutputsPerState); else @@ -115,225 +109,6 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}")); } -void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState) -{ - int unknownValueThreshold; - std::vector<int> bufferContext, stackContext; - std::vector<std::string> columns, focusedColumns, treeEmbeddingColumns; - std::vector<int> focusedBuffer, focusedStack; - std::vector<int> treeEmbeddingBuffer, treeEmbeddingStack; - std::vector<int> maxNbElements; - std::vector<int> treeEmbeddingNbElems; - std::vector<std::pair<int, float>> mlp; - int rawInputLeftWindow, rawInputRightWindow; - int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, treeEmbeddingSize; - bool bilstm, drop2d; - float lstmDropout, embeddingsDropout, totalInputDropout; - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Unknown value threshold :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&unknownValueThreshold](auto sm) - { - unknownValueThreshold = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(Unknown value threshold :) unknownValueThreshold")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Buffer context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&bufferContext](auto sm) - { - for (auto & index : util::split(sm.str(1), ' ')) - bufferContext.emplace_back(std::stoi(index)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Buffer context :) {index1 index2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Stack context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&stackContext](auto sm) - { - for (auto & index : util::split(sm.str(1), ' ')) - stackContext.emplace_back(std::stoi(index)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Stack context :) {index1 index2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&columns](auto sm) - { - columns = util::split(sm.str(1), ' '); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Columns :) {index1 index2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedBuffer](auto sm) - { - for (auto & index : util::split(sm.str(1), ' ')) - focusedBuffer.emplace_back(std::stoi(index)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused buffer :) {index1 index2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedStack](auto sm) - { - for (auto & index : util::split(sm.str(1), ' ')) - focusedStack.emplace_back(std::stoi(index)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused stack :) {index1 index2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedColumns](auto sm) - { - focusedColumns = util::split(sm.str(1), ' '); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused columns :) {index1 index2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Max nb elements :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&maxNbElements](auto sm) - { - for (auto & index : util::split(sm.str(1), ' ')) - maxNbElements.emplace_back(std::stoi(index)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(Max nb elements :) {size1 size2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input left window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLeftWindow](auto sm) - { - rawInputLeftWindow = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input left window :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input right window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputRightWindow](auto sm) - { - rawInputRightWindow = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input right window :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsSize](auto sm) - { - embeddingsSize = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings size :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:MLP :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&mlp](auto sm) - { - auto params = util::split(sm.str(1), ' '); - if (params.size() % 2) - util::myThrow("MLP must have even number of parameters"); - for (unsigned int i = 0; i < params.size()/2; i++) - mlp.emplace_back(std::make_pair(std::stoi(params[2*i]), std::stof(params[2*i+1]))); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(MLP :) {hidden1 dropout1 hidden2 dropout2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Context LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&contextLSTMSize](auto sm) - { - contextLSTMSize = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Context LSTM size :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&focusedLSTMSize](auto sm) - { - focusedLSTMSize = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused LSTM size :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Rawinput LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLSTMSize](auto sm) - { - rawInputLSTMSize = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(Raw LSTM size :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Split trans LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&splitTransLSTMSize](auto sm) - { - splitTransLSTMSize = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Split trans LSTM size :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Num layers :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&nbLayers](auto sm) - { - nbLayers = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Num layers :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:BiLSTM :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&bilstm](auto sm) - { - bilstm = sm.str(1) == "true"; - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(BiLSTM :) true|false")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LSTM dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&lstmDropout](auto sm) - { - lstmDropout = std::stof(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(LSTM dropout :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Total input dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&totalInputDropout](auto sm) - { - totalInputDropout = std::stof(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Total input dropout :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsDropout](auto sm) - { - embeddingsDropout = std::stof(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings dropout :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Dropout 2d :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&drop2d](auto sm) - { - drop2d = sm.str(1) == "true"; - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Dropout 2d :) true|false")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingColumns](auto sm) - { - treeEmbeddingColumns = util::split(sm.str(1), ' '); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(Tree embedding columns :) {column1 column2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingBuffer](auto sm) - { - for (auto & index : util::split(sm.str(1), ' ')) - treeEmbeddingBuffer.emplace_back(std::stoi(index)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding buffer :) {index1 index2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingStack](auto sm) - { - for (auto & index : util::split(sm.str(1), ' ')) - treeEmbeddingStack.emplace_back(std::stoi(index)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding stack :) {index1 index2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding nb :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingNbElems](auto sm) - { - for (auto & index : util::split(sm.str(1), ' ')) - treeEmbeddingNbElems.emplace_back(std::stoi(index)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding nb :) {size1 size2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&treeEmbeddingSize](auto sm) - { - treeEmbeddingSize = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(Tree embedding size :) value")); - - this->nn.reset(new LSTMNetworkImpl(nbOutputsPerState, unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d)); -} - void Classifier::loadOptimizer(std::filesystem::path path) { torch::load(*optimizer, path); @@ -355,225 +130,6 @@ void Classifier::setState(const std::string & state) nn->setState(state); } -void Classifier::initGRU(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState) -{ - int unknownValueThreshold; - std::vector<int> bufferContext, stackContext; - std::vector<std::string> columns, focusedColumns, treeEmbeddingColumns; - std::vector<int> focusedBuffer, focusedStack; - std::vector<int> treeEmbeddingBuffer, treeEmbeddingStack; - std::vector<int> maxNbElements; - std::vector<int> treeEmbeddingNbElems; - std::vector<std::pair<int, float>> mlp; - int rawInputLeftWindow, rawInputRightWindow; - int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, treeEmbeddingSize; - bool bilstm, drop2d; - float lstmDropout, embeddingsDropout, totalInputDropout; - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Unknown value threshold :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&unknownValueThreshold](auto sm) - { - unknownValueThreshold = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Unknown value threshold :) unknownValueThreshold")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Buffer context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&bufferContext](auto sm) - { - for (auto & index : util::split(sm.str(1), ' ')) - bufferContext.emplace_back(std::stoi(index)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Buffer context :) {index1 index2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Stack context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&stackContext](auto sm) - { - for (auto & index : util::split(sm.str(1), ' ')) - stackContext.emplace_back(std::stoi(index)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Stack context :) {index1 index2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&columns](auto sm) - { - columns = util::split(sm.str(1), ' '); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(Columns :) {index1 index2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedBuffer](auto sm) - { - for (auto & index : util::split(sm.str(1), ' ')) - focusedBuffer.emplace_back(std::stoi(index)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused buffer :) {index1 index2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedStack](auto sm) - { - for (auto & index : util::split(sm.str(1), ' ')) - focusedStack.emplace_back(std::stoi(index)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused stack :) {index1 index2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedColumns](auto sm) - { - focusedColumns = util::split(sm.str(1), ' '); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused columns :) {index1 index2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Max nb elements :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&maxNbElements](auto sm) - { - for (auto & index : util::split(sm.str(1), ' ')) - maxNbElements.emplace_back(std::stoi(index)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Max nb elements :) {size1 size2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input left window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLeftWindow](auto sm) - { - rawInputLeftWindow = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input left window :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input right window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputRightWindow](auto sm) - { - rawInputRightWindow = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input right window :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsSize](auto sm) - { - embeddingsSize = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(Embeddings size :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:MLP :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&mlp](auto sm) - { - auto params = util::split(sm.str(1), ' '); - if (params.size() % 2) - util::myThrow("MLP must have even number of parameters"); - for (unsigned int i = 0; i < params.size()/2; i++) - mlp.emplace_back(std::make_pair(std::stoi(params[2*i]), std::stof(params[2*i+1]))); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(MLP :) {hidden1 dropout1 hidden2 dropout2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Context LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&contextLSTMSize](auto sm) - { - contextLSTMSize = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Context LSTM size :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&focusedLSTMSize](auto sm) - { - focusedLSTMSize = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused LSTM size :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Rawinput LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLSTMSize](auto sm) - { - rawInputLSTMSize = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw LSTM size :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Split trans LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&splitTransLSTMSize](auto sm) - { - splitTransLSTMSize = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Split trans LSTM size :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Num layers :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&nbLayers](auto sm) - { - nbLayers = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Num layers :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:BiLSTM :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&bilstm](auto sm) - { - bilstm = sm.str(1) == "true"; - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(BiLSTM :) true|false")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LSTM dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&lstmDropout](auto sm) - { - lstmDropout = std::stof(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(LSTM dropout :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Total input dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&totalInputDropout](auto sm) - { - totalInputDropout = std::stof(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Total input dropout :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsDropout](auto sm) - { - embeddingsDropout = std::stof(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings dropout :) value")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Dropout 2d :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&drop2d](auto sm) - { - drop2d = sm.str(1) == "true"; - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Dropout 2d :) true|false")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingColumns](auto sm) - { - treeEmbeddingColumns = util::split(sm.str(1), ' '); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding columns :) {column1 column2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingBuffer](auto sm) - { - for (auto & index : util::split(sm.str(1), ' ')) - treeEmbeddingBuffer.emplace_back(std::stoi(index)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding buffer :) {index1 index2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingStack](auto sm) - { - for (auto & index : util::split(sm.str(1), ' ')) - treeEmbeddingStack.emplace_back(std::stoi(index)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding stack :) {index1 index2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding nb :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingNbElems](auto sm) - { - for (auto & index : util::split(sm.str(1), ' ')) - treeEmbeddingNbElems.emplace_back(std::stoi(index)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(Tree embedding nb :) {size1 size2...}")); - - if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&treeEmbeddingSize](auto sm) - { - treeEmbeddingSize = std::stoi(sm.str(1)); - curIndex++; - })) - util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding size :) value")); - - this->nn.reset(new GRUNetworkImpl(nbOutputsPerState, unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d)); -} - void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState) { std::string anyBlanks = "(?:(?:\\s|\\t)*)"; diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp index cd6c33df504b5e977ae1cece09f6bf92dec60e7e..0d5cedda715558c166412d059bab28ce50d379c2 100644 --- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp +++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp @@ -7,7 +7,7 @@ #include "LSTM.hpp" #include "GRU.hpp" -class DepthLayerTreeEmbeddingModule : public Submodule +class DepthLayerTreeEmbeddingModuleImpl : public Submodule { private : @@ -15,16 +15,18 @@ class DepthLayerTreeEmbeddingModule : public Submodule std::vector<std::string> columns; std::vector<int> focusedBuffer; std::vector<int> focusedStack; + torch::nn::Embedding wordEmbeddings{nullptr}; std::vector<std::shared_ptr<MyModule>> depthModules; public : - DepthLayerTreeEmbeddingModule(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, MyModule::ModuleOptions options); + DepthLayerTreeEmbeddingModuleImpl(const std::string & definition); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; }; +TORCH_MODULE(DepthLayerTreeEmbeddingModule); #endif diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp index 9c2732aa82d0b23490a3d36ef527c26aa55a1822..c105193b65c3f26fb5cfdd6e674910b70843feb0 100644 --- a/torch_modules/include/FocusedColumnModule.hpp +++ b/torch_modules/include/FocusedColumnModule.hpp @@ -7,10 +7,11 @@ #include "LSTM.hpp" #include "GRU.hpp" -class FocusedColumnModule : public Submodule +class FocusedColumnModuleImpl : public Submodule { private : + torch::nn::Embedding wordEmbeddings{nullptr}; std::shared_ptr<MyModule> myModule{nullptr}; std::vector<int> focusedBuffer, focusedStack; std::string column; @@ -18,12 +19,13 @@ class FocusedColumnModule : public Submodule public : - FocusedColumnModule(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options); + 
FocusedColumnModuleImpl(const std::string & definition); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; }; +TORCH_MODULE(FocusedColumnModule); #endif diff --git a/torch_modules/include/GRUNetwork.hpp b/torch_modules/include/GRUNetwork.hpp deleted file mode 100644 index ecff8a05c7604c45b04fbc9f2caf6c0637dc7558..0000000000000000000000000000000000000000 --- a/torch_modules/include/GRUNetwork.hpp +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef GRUNETWORK__H -#define GRUNETWORK__H - -#include "NeuralNetwork.hpp" -#include "ContextModule.hpp" -#include "RawInputModule.hpp" -#include "SplitTransModule.hpp" -#include "FocusedColumnModule.hpp" -#include "DepthLayerTreeEmbeddingModule.hpp" -#include "MLP.hpp" - -class GRUNetworkImpl : public NeuralNetworkImpl -{ -// private : -// -// torch::nn::Embedding wordEmbeddings{nullptr}; -// torch::nn::Dropout2d embeddingsDropout2d{nullptr}; -// torch::nn::Dropout embeddingsDropout{nullptr}; -// torch::nn::Dropout inputDropout{nullptr}; -// -// MLP mlp{nullptr}; -// ContextModule contextGRU{nullptr}; -// RawInputModule rawInputGRU{nullptr}; -// SplitTransModule splitTransGRU{nullptr}; -// DepthLayerTreeEmbeddingModule treeEmbedding{nullptr}; -// std::vector<FocusedColumnModule> focusedLstms; -// std::map<std::string,torch::nn::Linear> outputLayersPerState; - - public : - - GRUNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextGRUSize, int focusedGRUSize, int rawInputGRUSize, int splitTransGRUSize, int numLayers, bool bigru, float gruDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d); - torch::Tensor forward(torch::Tensor input) override; - std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; -}; - -#endif diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp deleted file mode 100644 index d83acf5d0ada988972ca43e749644ad16fc49bae..0000000000000000000000000000000000000000 --- a/torch_modules/include/LSTMNetwork.hpp +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef LSTMNETWORK__H -#define LSTMNETWORK__H - -#include "NeuralNetwork.hpp" -#include "ContextModule.hpp" -#include "RawInputModule.hpp" -#include "SplitTransModule.hpp" -#include "FocusedColumnModule.hpp" -#include "DepthLayerTreeEmbeddingModule.hpp" -#include "MLP.hpp" - -class LSTMNetworkImpl : public NeuralNetworkImpl -{ -// private : -// -// torch::nn::Embedding wordEmbeddings{nullptr}; -// torch::nn::Dropout2d embeddingsDropout2d{nullptr}; -// torch::nn::Dropout embeddingsDropout{nullptr}; -// torch::nn::Dropout inputDropout{nullptr}; -// -// MLP mlp{nullptr}; -// ContextModule contextLSTM{nullptr}; -// RawInputModule rawInputLSTM{nullptr}; -// SplitTransModule splitTransLSTM{nullptr}; -// DepthLayerTreeEmbeddingModule 
treeEmbedding{nullptr}; -// std::vector<FocusedColumnModule> focusedLstms; -// std::map<std::string,torch::nn::Linear> outputLayersPerState; - - public : - - LSTMNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d); - torch::Tensor forward(torch::Tensor input) override; - std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; -}; - -#endif diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp index 871c8bb3c1f00652a0b5a6849a412811c19d9094..931aca8187ffaeff43871fea3413360c7e069ed8 100644 --- a/torch_modules/include/ModularNetwork.hpp +++ b/torch_modules/include/ModularNetwork.hpp @@ -13,9 +13,9 @@ class ModularNetworkImpl : public NeuralNetworkImpl { private : - torch::nn::Embedding wordEmbeddings{nullptr}; - torch::nn::Dropout2d embeddingsDropout2d{nullptr}; - torch::nn::Dropout embeddingsDropout{nullptr}; + //torch::nn::Embedding wordEmbeddings{nullptr}; + //torch::nn::Dropout2d embeddingsDropout2d{nullptr}; + //torch::nn::Dropout embeddingsDropout{nullptr}; torch::nn::Dropout inputDropout{nullptr}; MLP mlp{nullptr}; diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp index 0134f974d7640f77cc8334f2828c71863a939230..4ded9154d4536031bc36e098ee47a15335f96054 100644 --- a/torch_modules/include/RawInputModule.hpp +++ b/torch_modules/include/RawInputModule.hpp @@ -7,21 +7,23 @@ #include "LSTM.hpp" #include "GRU.hpp" -class RawInputModule : public Submodule +class RawInputModuleImpl : public Submodule { private : + torch::nn::Embedding wordEmbeddings{nullptr}; std::shared_ptr<MyModule> myModule{nullptr}; int leftWindow, rightWindow; public : - RawInputModule(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options); + RawInputModuleImpl(const std::string & definition); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; }; +TORCH_MODULE(RawInputModule); #endif diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp index 8be8e633fab7f182242940eee1308b3cabe0d676..24c68411e5b5bf3c232b561fc77a16030326a25b 100644 --- a/torch_modules/include/SplitTransModule.hpp +++ b/torch_modules/include/SplitTransModule.hpp @@ -7,21 +7,23 @@ #include "LSTM.hpp" #include "GRU.hpp" -class SplitTransModule : public Submodule +class SplitTransModuleImpl : public Submodule { private : + torch::nn::Embedding wordEmbeddings{nullptr}; std::shared_ptr<MyModule> myModule{nullptr}; int maxNbTrans; public : - SplitTransModule(int maxNbTrans, 
int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options);
+  SplitTransModuleImpl(int maxNbTrans, const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
+TORCH_MODULE(SplitTransModule);
 
 #endif
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index f3d7bb2881e5a3d187aa2e95d426168bbabce976..7e13cdc2d1d4447f4b6141f2932dc981aed9a905 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -1,15 +1,56 @@
 #include "DepthLayerTreeEmbeddingModule.hpp"
 
-DepthLayerTreeEmbeddingModule::DepthLayerTreeEmbeddingModule(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, MyModule::ModuleOptions options) :
-  maxElemPerDepth(maxElemPerDepth), columns(columns), focusedBuffer(focusedBuffer), focusedStack(focusedStack)
+DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(const std::string & definition)
 {
-  for (unsigned int i = 0; i < maxElemPerDepth.size(); i++)
-    depthModules.emplace_back(register_module(fmt::format("lstm_{}",i), LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options)));
+  std::regex regex("(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)LayerSizes\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+    {
+      try
+      {
+        columns = util::split(sm.str(1), ' ');
+
+        for (auto & index : util::split(sm.str(2), ' '))
+          focusedBuffer.emplace_back(std::stoi(index));
+
+        for (auto & index : util::split(sm.str(3), ' '))
+          focusedStack.emplace_back(std::stoi(index));
+
+        for (auto & elem : util::split(sm.str(4), ' '))
+          maxElemPerDepth.emplace_back(std::stoi(elem));
+
+        auto subModuleType = sm.str(5);
+        auto subModuleArguments = util::split(sm.str(6), ' ');
+
+        auto options = MyModule::ModuleOptions(true)
+          .bidirectional(std::stoi(subModuleArguments[0]))
+          .num_layers(std::stoi(subModuleArguments[1]))
+          .dropout(std::stof(subModuleArguments[2]))
+          .complete(std::stoi(subModuleArguments[3]));
+
+        int inSize = std::stoi(sm.str(7));
+        int outSize = std::stoi(sm.str(8));
+
+        wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
+
+        for (unsigned int i = 0; i < maxElemPerDepth.size(); i++)
+        {
+          std::string name = fmt::format("{}_{}", i, subModuleType);
+          if (subModuleType == "LSTM")
+            depthModules.emplace_back(register_module(name, LSTM(columns.size()*inSize, outSize, options)));
+          else if (subModuleType == "GRU")
+            depthModules.emplace_back(register_module(name, GRU(columns.size()*inSize, outSize, options)));
+          else
+            util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
+        }
+
+    } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+    }))
+    util::myThrow(fmt::format("invalid definition '{}'", definition));
 }
 
-torch::Tensor DepthLayerTreeEmbeddingModule::forward(torch::Tensor input)
+torch::Tensor DepthLayerTreeEmbeddingModuleImpl::forward(torch::Tensor input)
 {
-  auto 
context = input.narrow(1, firstInputIndex, getInputSize()); + auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize())); std::vector<torch::Tensor> outputs; @@ -17,14 +58,14 @@ torch::Tensor DepthLayerTreeEmbeddingModule::forward(torch::Tensor input) for (unsigned int focused = 0; focused < focusedBuffer.size()+focusedStack.size(); focused++) for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++) { - outputs.emplace_back(depthModules[depth]->forward(context.narrow(1, offset, maxElemPerDepth[depth]*columns.size()).view({input.size(0), maxElemPerDepth[depth], (long)columns.size()*input.size(2)}))); + outputs.emplace_back(depthModules[depth]->forward(context.narrow(1, offset, maxElemPerDepth[depth]*columns.size()).view({context.size(0), maxElemPerDepth[depth], (long)columns.size()*context.size(2)}))); offset += maxElemPerDepth[depth]*columns.size(); } return torch::cat(outputs, 1); } -std::size_t DepthLayerTreeEmbeddingModule::getOutputSize() +std::size_t DepthLayerTreeEmbeddingModuleImpl::getOutputSize() { std::size_t outputSize = 0; @@ -34,7 +75,7 @@ std::size_t DepthLayerTreeEmbeddingModule::getOutputSize() return outputSize*(focusedBuffer.size()+focusedStack.size()); } -std::size_t DepthLayerTreeEmbeddingModule::getInputSize() +std::size_t DepthLayerTreeEmbeddingModuleImpl::getInputSize() { int inputSize = 0; for (int maxElem : maxElemPerDepth) @@ -42,7 +83,7 @@ std::size_t DepthLayerTreeEmbeddingModule::getInputSize() return inputSize; } -void DepthLayerTreeEmbeddingModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const +void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const { std::vector<long> focusedIndexes; diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 717bc747c85be8521271d80499698af525741585..08cb9eb5aefc7589cebfac18aaebcfcf1204c69a 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -1,30 +1,67 @@ #include "FocusedColumnModule.hpp" -FocusedColumnModule::FocusedColumnModule(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options) : focusedBuffer(focusedBuffer), focusedStack(focusedStack), column(column), maxNbElements(maxNbElements) +FocusedColumnModuleImpl::FocusedColumnModuleImpl(const std::string & definition) { - myModule = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options)); + std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); + if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) + { + try + { + column = sm.str(1); + maxNbElements = std::stoi(sm.str(2)); + + for (auto & index : util::split(sm.str(3), ' ')) + focusedBuffer.emplace_back(std::stoi(index)); + + for (auto & index : util::split(sm.str(4), ' ')) + focusedStack.emplace_back(std::stoi(index)); + + auto subModuleType = sm.str(5); + auto subModuleArguments = util::split(sm.str(6), ' '); + + auto options = MyModule::ModuleOptions(true) + .bidirectional(std::stoi(subModuleArguments[0])) + .num_layers(std::stoi(subModuleArguments[1])) + .dropout(std::stof(subModuleArguments[2])) 
+          .complete(std::stoi(subModuleArguments[3]));
+
+        int inSize = std::stoi(sm.str(7));
+        int outSize = std::stoi(sm.str(8));
+
+        wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
+
+        if (subModuleType == "LSTM")
+          myModule = register_module("myModule", LSTM(inSize, outSize, options));
+        else if (subModuleType == "GRU")
+          myModule = register_module("myModule", GRU(inSize, outSize, options));
+        else
+          util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
+
+    } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+    }))
+    util::myThrow(fmt::format("invalid definition '{}'", definition));
 }
 
-torch::Tensor FocusedColumnModule::forward(torch::Tensor input)
+torch::Tensor FocusedColumnModuleImpl::forward(torch::Tensor input)
 {
   std::vector<torch::Tensor> outputs;
 
   for (unsigned int i = 0; i < focusedBuffer.size()+focusedStack.size(); i++)
-    outputs.emplace_back(myModule->forward(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements)));
+    outputs.emplace_back(myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements))));
 
   return torch::cat(outputs, 1);
 }
 
-std::size_t FocusedColumnModule::getOutputSize()
+std::size_t FocusedColumnModuleImpl::getOutputSize()
 {
   return (focusedBuffer.size()+focusedStack.size())*myModule->getOutputSize(maxNbElements);
 }
 
-std::size_t FocusedColumnModule::getInputSize()
+std::size_t FocusedColumnModuleImpl::getInputSize()
 {
   return (focusedBuffer.size()+focusedStack.size()) * maxNbElements;
 }
 
-void FocusedColumnModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   std::vector<long> focusedIndexes;
diff --git a/torch_modules/src/GRUNetwork.cpp b/torch_modules/src/GRUNetwork.cpp
deleted file mode 100644
index dc7f36609d5a83f94486ee30dc3025bc30ca5bcc..0000000000000000000000000000000000000000
--- a/torch_modules/src/GRUNetwork.cpp
+++ /dev/null
@@ -1,120 +0,0 @@
-#include "GRUNetwork.hpp"
-
-GRUNetworkImpl::GRUNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextGRUSize, int focusedGRUSize, int rawInputGRUSize, int splitTransGRUSize, int numLayers, bool bigru, float gruDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d)
-{
-// MyModule::ModuleOptions gruOptions{true,bigru,numLayers,gruDropout,false};
-// auto gruOptionsAll = gruOptions;
-// std::get<4>(gruOptionsAll) = true;
-//
-// int currentOutputSize = embeddingsSize;
-// int currentInputSize = 1;
-//
-// contextGRU = register_module("contextGRU", ContextModule(columns, embeddingsSize, contextGRUSize, bufferContext, stackContext, gruOptions, unknownValueThreshold));
-// contextGRU->setFirstInputIndex(currentInputSize);
-// currentOutputSize += 
contextGRU->getOutputSize(); -// currentInputSize += contextGRU->getInputSize(); -// -// if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0) -// { -// rawInputGRU = register_module("rawInputGRU", RawInputModule(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputGRUSize, gruOptionsAll)); -// rawInputGRU->setFirstInputIndex(currentInputSize); -// currentOutputSize += rawInputGRU->getOutputSize(); -// currentInputSize += rawInputGRU->getInputSize(); -// } -// -// if (!treeEmbeddingColumns.empty()) -// { -// treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbeddingModule(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,gruOptions)); -// treeEmbedding->setFirstInputIndex(currentInputSize); -// currentOutputSize += treeEmbedding->getOutputSize(); -// currentInputSize += treeEmbedding->getInputSize(); -// } -// -// splitTransGRU = register_module("splitTransGRU", SplitTransModule(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransGRUSize, gruOptionsAll)); -// splitTransGRU->setFirstInputIndex(currentInputSize); -// currentOutputSize += splitTransGRU->getOutputSize(); -// currentInputSize += splitTransGRU->getInputSize(); -// -// for (unsigned int i = 0; i < focusedColumns.size(); i++) -// { -// focusedLstms.emplace_back(register_module(fmt::format("GRU_{}", focusedColumns[i]), FocusedColumnModule(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedGRUSize, gruOptions))); -// focusedLstms.back()->setFirstInputIndex(currentInputSize); -// currentOutputSize += focusedLstms.back()->getOutputSize(); -// currentInputSize += focusedLstms.back()->getInputSize(); -// } -// -// wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); -// if (drop2d) -// embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue)); -// else -// embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue)); -// inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout)); -// -// mlp = register_module("mlp", MLP(currentOutputSize, mlpParams)); -// -// for (auto & it : nbOutputsPerState) -// outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second))); -} - -torch::Tensor GRUNetworkImpl::forward(torch::Tensor input) -{ - return input; -// if (input.dim() == 1) -// input = input.unsqueeze(0); -// -// auto embeddings = wordEmbeddings(input); -// if (embeddingsDropout2d.is_empty()) -// embeddings = embeddingsDropout(embeddings); -// else -// embeddings = embeddingsDropout2d(embeddings); -// -// std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)}; -// -// outputs.emplace_back(contextGRU(embeddings)); -// -// if (!rawInputGRU.is_empty()) -// outputs.emplace_back(rawInputGRU(embeddings)); -// -// if (!treeEmbedding.is_empty()) -// outputs.emplace_back(treeEmbedding(embeddings)); -// -// outputs.emplace_back(splitTransGRU(embeddings)); -// -// for (auto & gru : focusedLstms) -// outputs.emplace_back(gru(embeddings)); -// -// auto totalInput = inputDropout(torch::cat(outputs, 1)); -// -// return outputLayersPerState.at(getState())(mlp(totalInput)); -} - -std::vector<std::vector<long>> GRUNetworkImpl::extractContext(Config & config, Dict & dict) const -{ - std::vector<std::vector<long>> 
context; - return context; -// if (dict.size() >= maxNbEmbeddings) -// util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings)); -// -// context.emplace_back(); -// -// context.back().emplace_back(dict.getIndexOrInsert(config.getState())); -// -// contextGRU->addToContext(context, dict, config, mustSplitUnknown()); -// -// if (!rawInputGRU.is_empty()) -// rawInputGRU->addToContext(context, dict, config, mustSplitUnknown()); -// -// if (!treeEmbedding.is_empty()) -// treeEmbedding->addToContext(context, dict, config, mustSplitUnknown()); -// -// splitTransGRU->addToContext(context, dict, config, mustSplitUnknown()); -// -// for (auto & gru : focusedLstms) -// gru->addToContext(context, dict, config, mustSplitUnknown()); -// -// if (!mustSplitUnknown() && context.size() > 1) -// util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size())); -// -// return context; -} - diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp deleted file mode 100644 index e46d175aec9d5e83f1d365d40cdd0b7e5c3dc977..0000000000000000000000000000000000000000 --- a/torch_modules/src/LSTMNetwork.cpp +++ /dev/null @@ -1,120 +0,0 @@ -#include "LSTMNetwork.hpp" - -LSTMNetworkImpl::LSTMNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d) -{ -// MyModule::ModuleOptions moduleOptions{true,bilstm,numLayers,lstmDropout,false}; -// auto moduleOptionsAll = moduleOptions; -// std::get<4>(moduleOptionsAll) = true; -// -// int currentOutputSize = embeddingsSize; -// int currentInputSize = 1; -// -// contextLSTM = register_module("contextLSTM", ContextModule(columns, embeddingsSize, contextLSTMSize, bufferContext, stackContext, moduleOptions, unknownValueThreshold)); -// contextLSTM->setFirstInputIndex(currentInputSize); -// currentOutputSize += contextLSTM->getOutputSize(); -// currentInputSize += contextLSTM->getInputSize(); -// -// if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0) -// { -// rawInputLSTM = register_module("rawInputLSTM", RawInputModule(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, moduleOptionsAll)); -// rawInputLSTM->setFirstInputIndex(currentInputSize); -// currentOutputSize += rawInputLSTM->getOutputSize(); -// currentInputSize += rawInputLSTM->getInputSize(); -// } -// -// if (!treeEmbeddingColumns.empty()) -// { -// treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbeddingModule(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,moduleOptions)); -// treeEmbedding->setFirstInputIndex(currentInputSize); -// currentOutputSize += treeEmbedding->getOutputSize(); -// currentInputSize += 
treeEmbedding->getInputSize(); -// } -// -// splitTransLSTM = register_module("splitTransLSTM", SplitTransModule(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransLSTMSize, moduleOptionsAll)); -// splitTransLSTM->setFirstInputIndex(currentInputSize); -// currentOutputSize += splitTransLSTM->getOutputSize(); -// currentInputSize += splitTransLSTM->getInputSize(); -// -// for (unsigned int i = 0; i < focusedColumns.size(); i++) -// { -// focusedLstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), FocusedColumnModule(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedLSTMSize, moduleOptions))); -// focusedLstms.back()->setFirstInputIndex(currentInputSize); -// currentOutputSize += focusedLstms.back()->getOutputSize(); -// currentInputSize += focusedLstms.back()->getInputSize(); -// } -// -// wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); -// if (drop2d) -// embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue)); -// else -// embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue)); -// inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout)); -// -// mlp = register_module("mlp", MLP(currentOutputSize, mlpParams)); -// -// for (auto & it : nbOutputsPerState) -// outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second))); -} - -torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input) -{ - return input; -// if (input.dim() == 1) -// input = input.unsqueeze(0); -// -// auto embeddings = wordEmbeddings(input); -// if (embeddingsDropout2d.is_empty()) -// embeddings = embeddingsDropout(embeddings); -// else -// embeddings = embeddingsDropout2d(embeddings); -// -// std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)}; -// -// outputs.emplace_back(contextLSTM(embeddings)); -// -// if (!rawInputLSTM.is_empty()) -// outputs.emplace_back(rawInputLSTM(embeddings)); -// -// if (!treeEmbedding.is_empty()) -// outputs.emplace_back(treeEmbedding(embeddings)); -// -// outputs.emplace_back(splitTransLSTM(embeddings)); -// -// for (auto & lstm : focusedLstms) -// outputs.emplace_back(lstm(embeddings)); -// -// auto totalInput = inputDropout(torch::cat(outputs, 1)); -// -// return outputLayersPerState.at(getState())(mlp(totalInput)); -} - -std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const -{ - std::vector<std::vector<long>> context; - return context; -// if (dict.size() >= maxNbEmbeddings) -// util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings)); -// -// context.emplace_back(); -// -// context.back().emplace_back(dict.getIndexOrInsert(config.getState())); -// -// contextLSTM->addToContext(context, dict, config, mustSplitUnknown()); -// -// if (!rawInputLSTM.is_empty()) -// rawInputLSTM->addToContext(context, dict, config, mustSplitUnknown()); -// -// if (!treeEmbedding.is_empty()) -// treeEmbedding->addToContext(context, dict, config, mustSplitUnknown()); -// -// splitTransLSTM->addToContext(context, dict, config, mustSplitUnknown()); -// -// for (auto & lstm : focusedLstms) -// lstm->addToContext(context, dict, config, mustSplitUnknown()); -// -// if (!mustSplitUnknown() && context.size() > 1) -// 
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index 21edb84ffb337b5f0bf544e5bf71580153888335..47bf5b16d285451522eccdc2c5708ea4b937724b 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -23,6 +23,14 @@ ModularNetworkImpl::ModularNetworkImpl(std::map<std::string,std::size_t> nbOutpu
     std::string name = fmt::format("{}_{}", modules.size(), splited.first);
     if (splited.first == "Context")
       modules.emplace_back(register_module(name, ContextModule(splited.second)));
+    else if (splited.first == "Focused")
+      modules.emplace_back(register_module(name, FocusedColumnModule(splited.second)));
+    else if (splited.first == "RawInput")
+      modules.emplace_back(register_module(name, RawInputModule(splited.second)));
+    else if (splited.first == "SplitTrans")
+      modules.emplace_back(register_module(name, SplitTransModule(Config::maxNbAppliableSplitTransitions, splited.second)));
+    else if (splited.first == "DepthLayerTree")
+      modules.emplace_back(register_module(name, DepthLayerTreeEmbeddingModule(splited.second)));
     else if (splited.first == "MLP")
     {
       mlpDef = splited.second;
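With the dispatch above, a classifier's network can now be declared module by module in the machine file. The exact surrounding syntax (how each line is split into the splited.first / splited.second pair) is defined elsewhere in the repository, so the following is only a hedged sketch consistent with the regexes introduced in the two hunks below; the window and size values are made up:

    RawInput : Left{2} Right{2} LSTM{1 1 0.3 1} In{64} Out{128}
    SplitTrans : GRU{0 1 0.3 1} In{64} Out{64}

Here the four numbers inside the LSTM{...} or GRU{...} block are bidirectional, num_layers, dropout, and complete, while In{...} and Out{...} give the embedding size and the recurrent layer's output size.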
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index e14451dd206b2dea8a3d50f91a603fca97b74e26..9c5e5412bfaf929eaab7455d229208e77f4cc599 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -1,26 +1,57 @@
 #include "RawInputModule.hpp"
 
-RawInputModule::RawInputModule(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options) : leftWindow(leftWindow), rightWindow(rightWindow)
+RawInputModuleImpl::RawInputModuleImpl(const std::string & definition)
 {
-  myModule = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
+  std::regex regex("(?:(?:\\s|\\t)*)Left\\{(.*)\\}(?:(?:\\s|\\t)*)Right\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+    {
+      try
+      {
+        leftWindow = std::stoi(sm.str(1));
+        rightWindow = std::stoi(sm.str(2));
+
+        auto subModuleType = sm.str(3);
+        auto subModuleArguments = util::split(sm.str(4), ' ');
+
+        auto options = MyModule::ModuleOptions(true)
+          .bidirectional(std::stoi(subModuleArguments[0]))
+          .num_layers(std::stoi(subModuleArguments[1]))
+          .dropout(std::stof(subModuleArguments[2]))
+          .complete(std::stoi(subModuleArguments[3]));
+
+        int inSize = std::stoi(sm.str(5));
+        int outSize = std::stoi(sm.str(6));
+
+        wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
+
+        if (subModuleType == "LSTM")
+          myModule = register_module("myModule", LSTM(inSize, outSize, options));
+        else if (subModuleType == "GRU")
+          myModule = register_module("myModule", GRU(inSize, outSize, options));
+        else
+          util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
+
+      } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+    }))
+    util::myThrow(fmt::format("invalid definition '{}'", definition));
 }
 
-torch::Tensor RawInputModule::forward(torch::Tensor input)
+torch::Tensor RawInputModuleImpl::forward(torch::Tensor input)
 {
-  return myModule->forward(input.narrow(1, firstInputIndex, getInputSize()));
+  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize())));
 }
 
-std::size_t RawInputModule::getOutputSize()
+std::size_t RawInputModuleImpl::getOutputSize()
 {
   return myModule->getOutputSize(leftWindow + rightWindow + 1);
 }
 
-std::size_t RawInputModule::getInputSize()
+std::size_t RawInputModuleImpl::getInputSize()
 {
   return leftWindow + rightWindow + 1;
 }
 
-void RawInputModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   if (leftWindow < 0 or rightWindow < 0)
     return;
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index 45a48df4488e884af925d6b50bfca69bdba73e60..ab1276c10d9e16eedec4638129b0287f0001a0e9 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -1,27 +1,55 @@
 #include "SplitTransModule.hpp"
 #include "Transition.hpp"
 
-SplitTransModule::SplitTransModule(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options) : maxNbTrans(maxNbTrans)
+SplitTransModuleImpl::SplitTransModuleImpl(int maxNbTrans, const std::string & definition)
 {
-  myModule = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
+  std::regex regex("(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+    {
+      try
+      {
+        auto subModuleType = sm.str(1);
+        auto subModuleArguments = util::split(sm.str(2), ' ');
+
+        auto options = MyModule::ModuleOptions(true)
+          .bidirectional(std::stoi(subModuleArguments[0]))
+          .num_layers(std::stoi(subModuleArguments[1]))
+          .dropout(std::stof(subModuleArguments[2]))
+          .complete(std::stoi(subModuleArguments[3]));
+
+        int inSize = std::stoi(sm.str(3));
+        int outSize = std::stoi(sm.str(4));
+
+        wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
+
+        if (subModuleType == "LSTM")
+          myModule = register_module("myModule", LSTM(inSize, outSize, options));
+        else if (subModuleType == "GRU")
+          myModule = register_module("myModule", GRU(inSize, outSize, options));
+        else
+          util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
+
+      } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+    }))
+    util::myThrow(fmt::format("invalid definition '{}'", definition));
 }
 
-torch::Tensor SplitTransModule::forward(torch::Tensor input)
+torch::Tensor SplitTransModuleImpl::forward(torch::Tensor input)
 {
-  return myModule->forward(input.narrow(1, firstInputIndex, getInputSize()));
+  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize())));
 }
 
-std::size_t SplitTransModule::getOutputSize()
+std::size_t SplitTransModuleImpl::getOutputSize()
 {
   return myModule->getOutputSize(maxNbTrans);
 }
 
-std::size_t SplitTransModule::getInputSize()
+std::size_t SplitTransModuleImpl::getInputSize()
 {
   return maxNbTrans;
 }
 
-void SplitTransModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   auto & splitTransitions = config.getAppliableSplitTransitions();
   for (auto & contextElement : context)