diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index 013a097e78cfab2566d5a8b141a68e612e8c61af..2e2b51dd163140113d0f8d5d8ac95edb4b548a1b 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -15,11 +15,12 @@ class Classifier private : - void initNeuralNetwork(const std::string & topology); + void initNeuralNetwork(const std::vector<std::string> & definition); + void initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex); public : - Classifier(const std::string & name, const std::string & topology, const std::vector<std::string> & tsFile); + Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition); TransitionSet & getTransitionSet(); NeuralNetwork & getNN(); const std::string & getName() const; diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 20a505639a96547459148c2752af5602b0b06e08..af823b77e5fbaeed858785d087457cb98ac16a58 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -3,11 +3,21 @@ #include "LSTMNetwork.hpp" #include "RandomNetwork.hpp" -Classifier::Classifier(const std::string & name, const std::string & topology, const std::vector<std::string> & tsFiles) +Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition) { this->name = name; - this->transitionSet.reset(new TransitionSet(tsFiles)); - initNeuralNetwork(topology); + if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[0], [this,&path](auto sm) + { + std::vector<std::string> tsFiles; + + for (auto & tsFilename : util::split(sm.str(1), ' ')) + tsFiles.emplace_back(path.parent_path() / tsFilename); + + this->transitionSet.reset(new TransitionSet(tsFiles)); + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}")); + + initNeuralNetwork(definition); } TransitionSet & Classifier::getTransitionSet() @@ -25,66 +35,177 @@ const std::string & Classifier::getName() const return name; } -void Classifier::initNeuralNetwork(const std::string & topology) +void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) +{ + std::size_t curIndex = 1; + + std::string networkType; + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&networkType](auto sm) + { + networkType = sm.str(1); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Network type :) networkType")); + + if (networkType == "Random") + this->nn.reset(new RandomNetworkImpl(this->transitionSet->size())); + else if (networkType == "LSTM") + initLSTM(definition, curIndex); + else + util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM'", networkType)); +} + +void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex) { - static std::vector<std::tuple<std::regex, std::string, std::function<void(const std::smatch &)>>> initializers - { - { - std::regex("Random"), - "Random : Output is chosen at random.", - [this,topology](auto sm) - { - this->nn.reset(new RandomNetworkImpl(this->transitionSet->size())); - } - }, - { - std::regex("LSTM\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"), - "LSTM(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.", - [this,topology](auto sm) - { - std::vector<int> focusedBuffer, focusedStack, maxNbElements, bufferContext, stackContext; - std::vector<std::string> focusedColumns, columns; - for (auto s : util::split(sm.str(2), ',')) - bufferContext.emplace_back(std::stoi(s)); - for (auto s : util::split(sm.str(3), ',')) - stackContext.emplace_back(std::stoi(s)); - for (auto s : util::split(sm.str(4), ',')) - columns.emplace_back(s); - for (auto s : util::split(sm.str(5), ',')) - focusedBuffer.push_back(std::stoi(s)); - for (auto s : util::split(sm.str(6), ',')) - focusedStack.push_back(std::stoi(s)); - for (auto s : util::split(sm.str(7), ',')) - focusedColumns.emplace_back(s); - for (auto s : util::split(sm.str(8), ',')) - maxNbElements.push_back(std::stoi(s)); - if (focusedColumns.size() != maxNbElements.size()) - util::myThrow("focusedColumns.size() != maxNbElements.size()"); - this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm.str(9)), std::stoi(sm.str(10)))); - } - }, - }; - - std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology); - - for (auto & initializer : initializers) - try - { - if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer))) - { - this->nn->to(NeuralNetworkImpl::device); - return; - } - } - catch (std::exception & e) - { - errorMessage = fmt::format("Caught({}) {}", e.what(), errorMessage); - break; - } - - for (auto & initializer : initializers) - errorMessage += std::get<1>(initializer) + "\n"; - - util::myThrow(errorMessage); + int unknownValueThreshold; + std::vector<int> bufferContext, stackContext; + std::vector<std::string> columns; + std::vector<int> focusedBuffer, focusedStack; + std::vector<std::string> focusedColumns; + std::vector<int> maxNbElements; + int rawInputLeftWindow, rawInputRightWindow; + int embeddingsSize, hiddenSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers; + bool bilstm; + float lstmDropout; + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Unknown value threshold :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&unknownValueThreshold](auto sm) + { + unknownValueThreshold = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Unknown value threshold :) unknownValueThreshold")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Buffer context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&bufferContext](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + bufferContext.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Buffer context :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Stack context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&stackContext](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + stackContext.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Stack context :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&columns](auto sm) + { + columns = util::split(sm.str(1), ' '); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Columns :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedBuffer](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + focusedBuffer.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused buffer :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedStack](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + focusedStack.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused stack :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedColumns](auto sm) + { + focusedColumns = util::split(sm.str(1), ' '); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused columns :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Max nb elements :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&maxNbElements](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + maxNbElements.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Max nb elements :) {size1 size2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input left window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLeftWindow](auto sm) + { + rawInputLeftWindow = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input left window :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input right window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputRightWindow](auto sm) + { + rawInputRightWindow = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input right window :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsSize](auto sm) + { + embeddingsSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings size :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Hidden size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&hiddenSize](auto sm) + { + hiddenSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Hidden size :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Context LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&contextLSTMSize](auto sm) + { + contextLSTMSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Context LSTM size :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&focusedLSTMSize](auto sm) + { + focusedLSTMSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused LSTM size :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Rawinput LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLSTMSize](auto sm) + { + rawInputLSTMSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw LSTM size :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Split trans LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&splitTransLSTMSize](auto sm) + { + splitTransLSTMSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Split trans LSTM size :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Num layers :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&nbLayers](auto sm) + { + nbLayers = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Num layers :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:BiLSTM :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&bilstm](auto sm) + { + bilstm = sm.str(1) == "true"; + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(BiLSTM :) true|false")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LSTM dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&lstmDropout](auto sm) + { + lstmDropout = std::stof(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(LSTM dropout :) value")); + + this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, hiddenSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout)); } diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 94de1cbf43e15167fddad47b9ed6b043a0f08651..9bd3e2f22c5b779f040d696f01eeacdb7fcc71d3 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -57,18 +57,23 @@ void ReadingMachine::readFromFile(std::filesystem::path path) if (!util::doIfNameMatch(std::regex("Name : (.+)"), lines[curLine++], [this](auto sm){name = sm[1];})) util::myThrow("No name specified"); - while (util::doIfNameMatch(std::regex("Classifier : (.+) (.+) \\{(.+)\\}"), lines[curLine++], [this,path](auto sm) + while (util::doIfNameMatch(std::regex("Classifier : (.+)"), lines[curLine++], [this,path,&lines,&curLine](auto sm) { - std::vector<std::string> tsFiles = util::split(sm.str(3), ' '); - for (auto & tsFile : tsFiles) - tsFile = path.parent_path() / tsFile; - classifier.reset(new Classifier(sm.str(1), sm.str(2), tsFiles)); + std::vector<std::string> classifierDefinition; + if (lines[curLine] != "{") + util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine])); + + for (curLine++; curLine < lines.size(); curLine++) + { + if (lines[curLine] == "}") + break; + classifierDefinition.emplace_back(lines[curLine]); + } + classifier.reset(new Classifier(sm.str(1), path, classifierDefinition)); })); if (!classifier.get()) util::myThrow("No Classifier specified"); - --curLine; - util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine], [this,path,&curLine](auto sm) { this->splitWordTransitionSet.reset(new TransitionSet(path.parent_path() / sm.str(1))); diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp index 5762ad1ad5c32211bf190d05d6c259a08d793cb6..6f2b46fc2cb73bf44ff73a7c8413175d4e772451 100644 --- a/torch_modules/include/LSTMNetwork.hpp +++ b/torch_modules/include/LSTMNetwork.hpp @@ -27,7 +27,7 @@ class LSTMNetworkImpl : public NeuralNetworkImpl public : - LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput); + LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, int hiddenSize, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; }; diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp index 430b2267ce1ae3d4dba987751749e1b22a527c76..cfa8c458679bfefea76ea62b239c08fa63428b5e 100644 --- a/torch_modules/src/LSTMNetwork.cpp +++ b/torch_modules/src/LSTMNetwork.cpp @@ -1,14 +1,8 @@ #include "LSTMNetwork.hpp" -LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) +LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, int hiddenSize, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout) { - constexpr int embeddingsSize = 256; - constexpr int hiddenSize = 8192; - constexpr int contextLSTMSize = 1024; - constexpr int focusedLSTMSize = 256; - constexpr int rawInputLSTMSize = 32; - - LSTMImpl::LSTMOptions lstmOptions{true,true,2,0.3,false}; + LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false}; auto lstmOptionsAll = lstmOptions; std::get<4>(lstmOptionsAll) = true; @@ -29,7 +23,7 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std:: currentInputSize += rawInputLSTM->getInputSize(); } - splitTransLSTM = register_module("splitTransLSTM", SplitTransLSTM(Config::maxNbAppliableSplitTransitions, embeddingsSize, embeddingsSize, lstmOptionsAll)); + splitTransLSTM = register_module("splitTransLSTM", SplitTransLSTM(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransLSTMSize, lstmOptionsAll)); splitTransLSTM->setFirstInputIndex(currentInputSize); currentOutputSize += splitTransLSTM->getOutputSize(); currentInputSize += splitTransLSTM->getInputSize();