diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 94dd16beb55786200195827d18f708e0202a188a..1c68a0d8db414b9bd8a9fbd02692d2fe8ea5fce9 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -7,12 +7,13 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, bool printAdvancement) { + constexpr float beamThreshold = 0.1; + constexpr int printInterval = 50; + torch::AutoGradMode useGrad(false); machine.trainMode(false); machine.setDictsState(Dict::State::Closed); - machine.getStrategy().reset(); - constexpr int printInterval = 50; int nbExamplesProcessed = 0; auto pastTime = std::chrono::high_resolution_clock::now(); @@ -22,16 +23,15 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, try { - baseConfig.addPredicted(machine.getPredicted()); - baseConfig.setState(machine.getStrategy().getInitialState()); for (unsigned int i = 0; i < beamSize; i++) { beam.emplace_back(baseConfig); + beam.back().setStrategy(machine.getStrategyDefinition()); + beam.back().addPredicted(machine.getPredicted()); + beam.back().setState(beam.back().getStrategy().getInitialState()); endFlag.emplace_back(false); } - machine.getClassifier()->setState(machine.getStrategy().getInitialState()); - while (true) { if (machine.hasSplitWordTransitionSet()) @@ -41,6 +41,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, std::vector<std::vector<int>> appliableTransitions; for (auto & c : beam) { + machine.getClassifier()->setState(c.getState()); appliableTransitions.emplace_back(machine.getTransitionSet().getAppliableTransitions(c)); c.setAppliableTransitions(appliableTransitions.back()); } @@ -61,6 +62,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, for (unsigned int beamIndex = 0; beamIndex < beam.size(); beamIndex++) { auto & c = beam[beamIndex]; + machine.getClassifier()->setState(c.getState()); c.printForDebug(stderr); auto softmaxed = torch::softmax(predictions[beamIndex],-1); std::vector<std::pair<float,std::string>> toPrint; @@ -83,8 +85,11 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, if (endFlag[beamIndex]) continue; auto & c = beam[beamIndex]; + machine.getClassifier()->setState(c.getState()); int chosenTransition = -1; float bestScore = std::numeric_limits<float>::min(); + auto softmaxed = torch::softmax(predictions[beamIndex], 0); + std::vector<int> consideredTransitions; try { @@ -96,7 +101,10 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, chosenTransition = i; bestScore = score; } + if (softmaxed[i].item<float>() >= beamThreshold) + consideredTransitions.emplace_back(i); } + } catch(std::exception & e) {util::myThrow(e.what());} if (chosenTransition == -1) @@ -110,7 +118,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, transition->apply(c); c.addToHistory(transition->getName()); - auto movement = machine.getStrategy().getMovement(c, transition->getName()); + auto movement = c.getStrategy().getMovement(c, transition->getName()); if (debug) { //TODO improve this for beam search diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index b9de6887e894852d3716d994fa9e63e3e548e71f..4a003176a8c31bdca21f4097831a865e830ea4c6 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -8,6 +8,7 @@ #include <boost/circular_buffer.hpp> #include "util.hpp" #include "Dict.hpp" +#include "Strategy.hpp" class Transition; @@ -49,6 +50,7 @@ class Config int currentWordId{0}; std::vector<Transition *> appliableSplitTransitions; std::vector<int> appliableTransitions; + std::shared_ptr<Strategy> strategy; protected : @@ -156,6 +158,8 @@ class Config const std::vector<Transition *> & getAppliableSplitTransitions() const; const std::vector<int> & getAppliableTransitions() const; bool isExtraColumn(const std::string & colName) const; + void setStrategy(std::vector<std::string> & strategyDefinition); + Strategy & getStrategy(); }; #endif diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index f63c21aa7899250fdf11865ce6d3e0fe36a5b4ba..ab23fdcac68f0720f3d47b27b2c3b1df87de9337 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -4,7 +4,6 @@ #include <filesystem> #include <memory> #include "Classifier.hpp" -#include "Strategy.hpp" class ReadingMachine { @@ -19,7 +18,7 @@ class ReadingMachine std::string name; std::filesystem::path path; std::unique_ptr<Classifier> classifier; - std::unique_ptr<Strategy> strategy; + std::vector<std::string> strategyDefinition; std::set<std::string> predicted; std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr}; @@ -36,7 +35,7 @@ class ReadingMachine TransitionSet & getTransitionSet(); TransitionSet & getSplitWordTransitionSet(); bool hasSplitWordTransitionSet() const; - Strategy & getStrategy(); + std::vector<std::string> & getStrategyDefinition(); Classifier * getClassifier(); bool isPredicted(const std::string & columnName) const; const std::set<std::string> & getPredicted() const; diff --git a/reading_machine/include/Strategy.hpp b/reading_machine/include/Strategy.hpp index 79fc26cbc917d59ce17f93024f33dc60652e6b88..c907b71d4424d222076564d0a1e16a6f69c745f3 100644 --- a/reading_machine/include/Strategy.hpp +++ b/reading_machine/include/Strategy.hpp @@ -1,7 +1,10 @@ #ifndef STRATEGY__H #define STRATEGY__H -#include "Config.hpp" +#include <string> +#include <vector> + +class Config; class Strategy { diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 4a218a5daefc7eed4398c67f5ec14429896c1f3c..edf1a067a498321f0ba858437acfd1410c52dd34 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -727,3 +727,16 @@ std::size_t Config::getStackSize() const return stack.size(); } +void Config::setStrategy(std::vector<std::string> & strategyDefinition) +{ + strategy.reset(new Strategy(strategyDefinition)); +} + +Strategy & Config::getStrategy() +{ + if (strategy.get() == nullptr) + util::myThrow("strategy was not set"); + + return *strategy.get(); +} + diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 5ef6c43e235e673fe072c5f6257b52669bf10b44..edecf08befdfa60d32aa734e9c27e540441c1f65 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -90,7 +90,7 @@ void ReadingMachine::readFromFile(std::filesystem::path path) if (!util::doIfNameMatch(std::regex("Strategy"), lines[curLine++], [this,&lines,&curLine](auto sm) { - std::vector<std::string> strategyDefinition; + strategyDefinition.clear(); if (lines[curLine] != "{") util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine])); @@ -100,7 +100,6 @@ void ReadingMachine::readFromFile(std::filesystem::path path) break; strategyDefinition.emplace_back(lines[curLine]); } - strategy.reset(new Strategy(strategyDefinition)); })) util::myThrow("No Strategy specified"); @@ -122,9 +121,9 @@ TransitionSet & ReadingMachine::getSplitWordTransitionSet() return *splitWordTransitionSet; } -Strategy & ReadingMachine::getStrategy() +std::vector<std::string> & ReadingMachine::getStrategyDefinition() { - return *strategy; + return strategyDefinition; } Classifier * ReadingMachine::getClassifier() diff --git a/reading_machine/src/Strategy.cpp b/reading_machine/src/Strategy.cpp index af65052e489dca55c0fa08b5f38ccc65019c9b11..bcdd951bafe4e98ead8f60b5275c6d129660f432 100644 --- a/reading_machine/src/Strategy.cpp +++ b/reading_machine/src/Strategy.cpp @@ -1,4 +1,5 @@ #include "Strategy.hpp" +#include "Config.hpp" Strategy::Strategy(std::vector<std::string> definition) { diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index eb76934d0b58ddc14f8f392776eeef0b07de5c1b..72b56e6a00e011a952a8f5f1a3de348a52aaf710 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -41,9 +41,9 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p std::filesystem::create_directories(dir); config.addPredicted(machine.getPredicted()); - machine.getStrategy().reset(); - config.setState(machine.getStrategy().getInitialState()); - machine.getClassifier()->setState(machine.getStrategy().getInitialState()); + config.setStrategy(machine.getStrategyDefinition()); + config.setState(config.getStrategy().getInitialState()); + machine.getClassifier()->setState(config.getState()); auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch); bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile); @@ -132,7 +132,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p transition->apply(config); config.addToHistory(transition->getName()); - auto movement = machine.getStrategy().getMovement(config, transition->getName()); + auto movement = config.getStrategy().getMovement(config, transition->getName()); if (debug) fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second); if (movement == Strategy::endMovement) @@ -291,9 +291,9 @@ void Trainer::fillDicts(SubConfig & config, bool debug) torch::AutoGradMode useGrad(false); config.addPredicted(machine.getPredicted()); - machine.getStrategy().reset(); - config.setState(machine.getStrategy().getInitialState()); - machine.getClassifier()->setState(machine.getStrategy().getInitialState()); + config.setStrategy(machine.getStrategyDefinition()); + config.setState(config.getStrategy().getInitialState()); + machine.getClassifier()->setState(config.getState()); while (true) { @@ -325,7 +325,7 @@ void Trainer::fillDicts(SubConfig & config, bool debug) goldTransition->apply(config); config.addToHistory(goldTransition->getName()); - auto movement = machine.getStrategy().getMovement(config, goldTransition->getName()); + auto movement = config.getStrategy().getMovement(config, goldTransition->getName()); if (debug) fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", goldTransition->getName(), movement.first, movement.second); if (movement == Strategy::endMovement)