From 050839edd5b7727159ad8ac8811a1b91452718a4 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sun, 5 Apr 2020 16:14:33 +0200 Subject: [PATCH] Sequential --- decoder/src/Decoder.cpp | 3 +-- reading_machine/include/Classifier.hpp | 2 +- reading_machine/include/Config.hpp | 1 + reading_machine/include/SubConfig.hpp | 7 ++----- reading_machine/include/TransitionSet.hpp | 5 +++++ reading_machine/src/Classifier.cpp | 4 ++-- reading_machine/src/Config.cpp | 24 ++++++++++++++++++++++- reading_machine/src/ReadingMachine.cpp | 8 +++++++- reading_machine/src/SubConfig.cpp | 2 +- reading_machine/src/TransitionSet.cpp | 11 +++++++++++ torch_modules/src/LSTMNetwork.cpp | 14 +++++++++---- trainer/src/MacaonTrain.cpp | 4 ++-- trainer/src/Trainer.cpp | 6 +----- 13 files changed, 67 insertions(+), 24 deletions(-) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 7d563c8..3a2caef 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -81,8 +81,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool break; config.setState(movement.first); - if (!config.moveWordIndex(movement.second)) - util::myThrow("Cannot move word index !"); + config.moveWordIndexRelaxed(movement.second); } } catch(std::exception & e) {util::myThrow(e.what());} diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index 1131db7..013a097 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -19,7 +19,7 @@ class Classifier public : - Classifier(const std::string & name, const std::string & topology, const std::string & tsFile); + Classifier(const std::string & name, const std::string & topology, const std::vector<std::string> & tsFile); TransitionSet & getTransitionSet(); NeuralNetwork & getNN(); const std::string & getName() const; diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index ea092f0..c9b7d4d 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -108,6 +108,7 @@ class Config bool isTokenPredicted(std::size_t lineIndex) const; bool moveWordIndex(int relativeMovement); bool canMoveWordIndex(int relativeMovement) const; + void moveWordIndexRelaxed(int relativeMovement); bool moveCharacterIndex(int relativeMovement); bool canMoveCharacterIndex(int relativeMovement) const; bool rawInputOnlySeparatorsLeft() const; diff --git a/reading_machine/include/SubConfig.hpp b/reading_machine/include/SubConfig.hpp index 2f1efd1..05b2541 100644 --- a/reading_machine/include/SubConfig.hpp +++ b/reading_machine/include/SubConfig.hpp @@ -8,16 +8,13 @@ class SubConfig : public Config { private : - static constexpr std::size_t spanSize = 800; - - private : - const BaseConfig & model; + std::size_t spanSize; std::size_t firstLineIndex{0}; public : - SubConfig(BaseConfig & model); + SubConfig(BaseConfig & model, std::size_t spanSize); bool update(); bool needsUpdate(); std::size_t getNbColumns() const override; diff --git a/reading_machine/include/TransitionSet.hpp b/reading_machine/include/TransitionSet.hpp index a1bc2c1..d0c7c1f 100644 --- a/reading_machine/include/TransitionSet.hpp +++ b/reading_machine/include/TransitionSet.hpp @@ -12,8 +12,13 @@ class TransitionSet std::vector<Transition> transitions; + private : + + void addTransitionsFromFile(const std::string & filename); + public : + TransitionSet(const std::vector<std::string> & filenames); TransitionSet(const std::string & filename); std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c); Transition * getBestAppliableTransition(const Config & c); diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index a58c3b1..54dccff 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -6,10 +6,10 @@ #include "LSTMNetwork.hpp" #include "RandomNetwork.hpp" -Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile) +Classifier::Classifier(const std::string & name, const std::string & topology, const std::vector<std::string> & tsFiles) { this->name = name; - this->transitionSet.reset(new TransitionSet(tsFile)); + this->transitionSet.reset(new TransitionSet(tsFiles)); initNeuralNetwork(topology); } diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index bb2366c..8681cce 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -431,6 +431,28 @@ bool Config::moveWordIndex(int relativeMovement) return true; } +void Config::moveWordIndexRelaxed(int relativeMovement) +{ + int nbMovements = 0; + int increment = relativeMovement > 0 ? 1 : -1; + while (nbMovements != relativeMovement) + { + do + { + if (!has(0,wordIndex+increment,0)) + break; + wordIndex += increment; + } + while (isComment(wordIndex)); + nbMovements += relativeMovement > 0 ? 1 : -1; + } + + if (!isComment(wordIndex)) + return; + + moveWordIndex(-increment); +} + bool Config::canMoveWordIndex(int relativeMovement) const { int nbMovements = 0; @@ -515,7 +537,7 @@ void Config::setState(const std::string state) bool Config::stateIsDone() const { if (!rawInput.empty()) - return rawInputOnlySeparatorsLeft(); + return rawInputOnlySeparatorsLeft() and !has(0, wordIndex+1, 0) and !hasStack(0); return !has(0, wordIndex+1, 0) and !hasStack(0); } diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 96b5e71..9d4a854 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -57,7 +57,13 @@ void ReadingMachine::readFromFile(std::filesystem::path path) if (!util::doIfNameMatch(std::regex("Name : (.+)"), lines[curLine++], [this](auto sm){name = sm[1];})) util::myThrow("No name specified"); - while (util::doIfNameMatch(std::regex("Classifier : (.+) (.+) (.+)"), lines[curLine++], [this,path](auto sm){classifier.reset(new Classifier(sm.str(1), sm.str(2), path.parent_path() / sm.str(3)));})); + while (util::doIfNameMatch(std::regex("Classifier : (.+) (.+) \\{(.+)\\}"), lines[curLine++], [this,path](auto sm) + { + std::vector<std::string> tsFiles = util::split(sm.str(3), ' '); + for (auto & tsFile : tsFiles) + tsFile = path.parent_path() / tsFile; + classifier.reset(new Classifier(sm.str(1), sm.str(2), tsFiles)); + })); if (!classifier.get()) util::myThrow("No Classifier specified"); diff --git a/reading_machine/src/SubConfig.cpp b/reading_machine/src/SubConfig.cpp index 1b63ed2..c571b3c 100644 --- a/reading_machine/src/SubConfig.cpp +++ b/reading_machine/src/SubConfig.cpp @@ -1,6 +1,6 @@ #include "SubConfig.hpp" -SubConfig::SubConfig(BaseConfig & model) : Config(model.rawInput), model(model) +SubConfig::SubConfig(BaseConfig & model, std::size_t spanSize) : Config(model.rawInput), model(model), spanSize(spanSize) { wordIndex = model.wordIndex; characterIndex = model.characterIndex; diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp index 5d0df94..a6ed1b0 100644 --- a/reading_machine/src/TransitionSet.cpp +++ b/reading_machine/src/TransitionSet.cpp @@ -2,6 +2,17 @@ #include <limits> TransitionSet::TransitionSet(const std::string & filename) +{ + addTransitionsFromFile(filename); +} + +TransitionSet::TransitionSet(const std::vector<std::string> & filenames) +{ + for (auto & filename : filenames) + addTransitionsFromFile(filename); +} + +void TransitionSet::addTransitionsFromFile(const std::string & filename) { FILE * file = std::fopen(filename.c_str(), "r"); if (!file) diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp index 201feea..7ab2dc2 100644 --- a/torch_modules/src/LSTMNetwork.cpp +++ b/torch_modules/src/LSTMNetwork.cpp @@ -40,7 +40,7 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std:: totalLSTMOutputSize += lstms.back()->options.hidden_size() * (lstms.back()->options.bidirectional() ? 4 : 1) * (bufferFocused.size()+stackFocused.size()); } - linear1 = register_module("linear1", torch::nn::Linear(totalLSTMOutputSize, hiddenSize)); + linear1 = register_module("linear1", torch::nn::Linear(embeddingsSize+totalLSTMOutputSize, hiddenSize)); linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); } @@ -51,16 +51,20 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input) auto embeddings = embeddingsDropout(wordEmbeddings(input)); - auto splitTrans = embeddings.narrow(1, 0, Config::maxNbAppliableSplitTransitions); + auto state = embeddings.narrow(1, 0, 1).squeeze(1); - auto context = embeddings.narrow(1, splitTrans.size(1)+rawInputSize, getContextSize()); + auto splitTrans = embeddings.narrow(1, 1, Config::maxNbAppliableSplitTransitions); + + auto context = embeddings.narrow(1, 1+splitTrans.size(1)+rawInputSize, getContextSize()); context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}); - auto elementsEmbeddings = embeddings.narrow(1, splitTrans.size(1)+rawInputSize+context.size(1), input.size(1)-(splitTrans.size(1)+rawInputSize+context.size(1))); + auto elementsEmbeddings = embeddings.narrow(1, 1+splitTrans.size(1)+rawInputSize+context.size(1), input.size(1)-(1+splitTrans.size(1)+rawInputSize+context.size(1))); std::vector<torch::Tensor> lstmOutputs; + lstmOutputs.emplace_back(state); + if (rawInputSize != 0) { auto rawLetters = embeddings.narrow(1, splitTrans.size(1), rawInputSize); @@ -110,6 +114,8 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, std::vector<std::vector<long>> context; context.emplace_back(); + context.back().emplace_back(dict.getIndexOrInsert(config.getState())); + auto & splitTransitions = config.getAppliableSplitTransitions(); for (int i = 0; i < Config::maxNbAppliableSplitTransitions; i++) if (i < (int)splitTransitions.size()) diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 7b8e60f..ffce48f 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -101,7 +101,7 @@ int MacaonTrain::main() BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile); BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); - SubConfig config(goldConfig); + SubConfig config(goldConfig, goldConfig.getNbLines()); fillDicts(machine, goldConfig); @@ -109,7 +109,7 @@ int MacaonTrain::main() trainer.createDataset(config, debug); if (!computeDevScore) { - SubConfig devConfig(devGoldConfig); + SubConfig devConfig(devGoldConfig, devGoldConfig.getNbLines()); trainer.createDevDataset(devConfig, debug); } diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 57e7e51..6459bb3 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -99,11 +99,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch: break; config.setState(movement.first); - if (!config.moveWordIndex(movement.second)) - { - config.printForDebug(stderr); - util::myThrow(fmt::format("Cannot move word index by {}", movement.second)); - } + config.moveWordIndexRelaxed(movement.second); if (config.needsUpdate()) config.update(); -- GitLab