From b9bb67a5b553b5a072467aa146fc4263bf72c3ec Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sat, 23 May 2020 12:58:58 +0200 Subject: [PATCH] Strategy is now part of Config --- decoder/src/Decoder.cpp | 22 +++++++++++++++------- reading_machine/include/Config.hpp | 4 ++++ reading_machine/include/ReadingMachine.hpp | 5 ++--- reading_machine/include/Strategy.hpp | 5 ++++- reading_machine/src/Config.cpp | 13 +++++++++++++ reading_machine/src/ReadingMachine.cpp | 7 +++---- reading_machine/src/Strategy.cpp | 1 + trainer/src/Trainer.cpp | 16 ++++++++-------- 8 files changed, 50 insertions(+), 23 deletions(-) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 94dd16b..1c68a0d 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -7,12 +7,13 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, bool printAdvancement) { + constexpr float beamThreshold = 0.1; + constexpr int printInterval = 50; + torch::AutoGradMode useGrad(false); machine.trainMode(false); machine.setDictsState(Dict::State::Closed); - machine.getStrategy().reset(); - constexpr int printInterval = 50; int nbExamplesProcessed = 0; auto pastTime = std::chrono::high_resolution_clock::now(); @@ -22,16 +23,15 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, try { - baseConfig.addPredicted(machine.getPredicted()); - baseConfig.setState(machine.getStrategy().getInitialState()); for (unsigned int i = 0; i < beamSize; i++) { beam.emplace_back(baseConfig); + beam.back().setStrategy(machine.getStrategyDefinition()); + beam.back().addPredicted(machine.getPredicted()); + beam.back().setState(beam.back().getStrategy().getInitialState()); endFlag.emplace_back(false); } - machine.getClassifier()->setState(machine.getStrategy().getInitialState()); - while (true) { if (machine.hasSplitWordTransitionSet()) @@ -41,6 +41,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, std::vector<std::vector<int>> appliableTransitions; for (auto & c : beam) { + machine.getClassifier()->setState(c.getState()); appliableTransitions.emplace_back(machine.getTransitionSet().getAppliableTransitions(c)); c.setAppliableTransitions(appliableTransitions.back()); } @@ -61,6 +62,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, for (unsigned int beamIndex = 0; beamIndex < beam.size(); beamIndex++) { auto & c = beam[beamIndex]; + machine.getClassifier()->setState(c.getState()); c.printForDebug(stderr); auto softmaxed = torch::softmax(predictions[beamIndex],-1); std::vector<std::pair<float,std::string>> toPrint; @@ -83,8 +85,11 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, if (endFlag[beamIndex]) continue; auto & c = beam[beamIndex]; + machine.getClassifier()->setState(c.getState()); int chosenTransition = -1; float bestScore = std::numeric_limits<float>::min(); + auto softmaxed = torch::softmax(predictions[beamIndex], 0); + std::vector<int> consideredTransitions; try { @@ -96,7 +101,10 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, chosenTransition = i; bestScore = score; } + if (softmaxed[i].item<float>() >= beamThreshold) + consideredTransitions.emplace_back(i); } + } catch(std::exception & e) {util::myThrow(e.what());} if (chosenTransition == -1) @@ -110,7 +118,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, transition->apply(c); c.addToHistory(transition->getName()); - auto movement = machine.getStrategy().getMovement(c, transition->getName()); + auto movement = c.getStrategy().getMovement(c, transition->getName()); if (debug) { //TODO improve this for beam search diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index b9de688..4a00317 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -8,6 +8,7 @@ #include <boost/circular_buffer.hpp> #include "util.hpp" #include "Dict.hpp" +#include "Strategy.hpp" class Transition; @@ -49,6 +50,7 @@ class Config int currentWordId{0}; std::vector<Transition *> appliableSplitTransitions; std::vector<int> appliableTransitions; + std::shared_ptr<Strategy> strategy; protected : @@ -156,6 +158,8 @@ class Config const std::vector<Transition *> & getAppliableSplitTransitions() const; const std::vector<int> & getAppliableTransitions() const; bool isExtraColumn(const std::string & colName) const; + void setStrategy(std::vector<std::string> & strategyDefinition); + Strategy & getStrategy(); }; #endif diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index f63c21a..ab23fdc 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -4,7 +4,6 @@ #include <filesystem> #include <memory> #include "Classifier.hpp" -#include "Strategy.hpp" class ReadingMachine { @@ -19,7 +18,7 @@ class ReadingMachine std::string name; std::filesystem::path path; std::unique_ptr<Classifier> classifier; - std::unique_ptr<Strategy> strategy; + std::vector<std::string> strategyDefinition; std::set<std::string> predicted; std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr}; @@ -36,7 +35,7 @@ class ReadingMachine TransitionSet & getTransitionSet(); TransitionSet & getSplitWordTransitionSet(); bool hasSplitWordTransitionSet() const; - Strategy & getStrategy(); + std::vector<std::string> & getStrategyDefinition(); Classifier * getClassifier(); bool isPredicted(const std::string & columnName) const; const std::set<std::string> & getPredicted() const; diff --git a/reading_machine/include/Strategy.hpp b/reading_machine/include/Strategy.hpp index 79fc26c..c907b71 100644 --- a/reading_machine/include/Strategy.hpp +++ b/reading_machine/include/Strategy.hpp @@ -1,7 +1,10 @@ #ifndef STRATEGY__H #define STRATEGY__H -#include "Config.hpp" +#include <string> +#include <vector> + +class Config; class Strategy { diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 4a218a5..edf1a06 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -727,3 +727,16 @@ std::size_t Config::getStackSize() const return stack.size(); } +void Config::setStrategy(std::vector<std::string> & strategyDefinition) +{ + strategy.reset(new Strategy(strategyDefinition)); +} + +Strategy & Config::getStrategy() +{ + if (strategy.get() == nullptr) + util::myThrow("strategy was not set"); + + return *strategy.get(); +} + diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 5ef6c43..edecf08 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -90,7 +90,7 @@ void ReadingMachine::readFromFile(std::filesystem::path path) if (!util::doIfNameMatch(std::regex("Strategy"), lines[curLine++], [this,&lines,&curLine](auto sm) { - std::vector<std::string> strategyDefinition; + strategyDefinition.clear(); if (lines[curLine] != "{") util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine])); @@ -100,7 +100,6 @@ void ReadingMachine::readFromFile(std::filesystem::path path) break; strategyDefinition.emplace_back(lines[curLine]); } - strategy.reset(new Strategy(strategyDefinition)); })) util::myThrow("No Strategy specified"); @@ -122,9 +121,9 @@ TransitionSet & ReadingMachine::getSplitWordTransitionSet() return *splitWordTransitionSet; } -Strategy & ReadingMachine::getStrategy() +std::vector<std::string> & ReadingMachine::getStrategyDefinition() { - return *strategy; + return strategyDefinition; } Classifier * ReadingMachine::getClassifier() diff --git a/reading_machine/src/Strategy.cpp b/reading_machine/src/Strategy.cpp index af65052..bcdd951 100644 --- a/reading_machine/src/Strategy.cpp +++ b/reading_machine/src/Strategy.cpp @@ -1,4 +1,5 @@ #include "Strategy.hpp" +#include "Config.hpp" Strategy::Strategy(std::vector<std::string> definition) { diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index eb76934..72b56e6 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -41,9 +41,9 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p std::filesystem::create_directories(dir); config.addPredicted(machine.getPredicted()); - machine.getStrategy().reset(); - config.setState(machine.getStrategy().getInitialState()); - machine.getClassifier()->setState(machine.getStrategy().getInitialState()); + config.setStrategy(machine.getStrategyDefinition()); + config.setState(config.getStrategy().getInitialState()); + machine.getClassifier()->setState(config.getState()); auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch); bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile); @@ -132,7 +132,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p transition->apply(config); config.addToHistory(transition->getName()); - auto movement = machine.getStrategy().getMovement(config, transition->getName()); + auto movement = config.getStrategy().getMovement(config, transition->getName()); if (debug) fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second); if (movement == Strategy::endMovement) @@ -291,9 +291,9 @@ void Trainer::fillDicts(SubConfig & config, bool debug) torch::AutoGradMode useGrad(false); config.addPredicted(machine.getPredicted()); - machine.getStrategy().reset(); - config.setState(machine.getStrategy().getInitialState()); - machine.getClassifier()->setState(machine.getStrategy().getInitialState()); + config.setStrategy(machine.getStrategyDefinition()); + config.setState(config.getStrategy().getInitialState()); + machine.getClassifier()->setState(config.getState()); while (true) { @@ -325,7 +325,7 @@ void Trainer::fillDicts(SubConfig & config, bool debug) goldTransition->apply(config); config.addToHistory(goldTransition->getName()); - auto movement = machine.getStrategy().getMovement(config, goldTransition->getName()); + auto movement = config.getStrategy().getMovement(config, goldTransition->getName()); if (debug) fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", goldTransition->getName(), movement.first, movement.second); if (movement == Strategy::endMovement) -- GitLab