From b6fda44aaa039d647259991d47f411f87ce4b1c2 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sun, 16 Feb 2020 23:30:30 +0100 Subject: [PATCH] Forced EOS at the end of decode --- decoder/src/Decoder.cpp | 12 ++++++++++++ reading_machine/include/TransitionSet.hpp | 1 + reading_machine/src/Action.cpp | 8 ++++---- reading_machine/src/Config.cpp | 4 ++-- reading_machine/src/Strategy.cpp | 2 +- reading_machine/src/TransitionSet.cpp | 10 ++++++++++ trainer/src/Trainer.cpp | 2 ++ 7 files changed, 32 insertions(+), 7 deletions(-) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index bcb25bd..6f209dd 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -40,6 +40,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug) config.addToHistory(transition->getName()); auto movement = machine.getStrategy().getMovement(config, transition->getName()); + if (debug) + fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second); if (movement == Strategy::endMovement) break; @@ -48,6 +50,16 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug) util::myThrow("Cannot move word index !"); } } catch(std::exception & e) {util::myThrow(e.what());} + + // Force EOS when needed + if (machine.getTransitionSet().getTransition("EOS") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1) + { + Action shift = Action::pushWordIndexOnStack(); + shift.apply(config, shift); + machine.getTransitionSet().getTransition("EOS")->apply(config); + if (debug) + fmt::print(stderr, "Forcing EOS transition\n"); + } } float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const diff --git a/reading_machine/include/TransitionSet.hpp b/reading_machine/include/TransitionSet.hpp index df9551c..4263ba4 100644 --- a/reading_machine/include/TransitionSet.hpp +++ b/reading_machine/include/TransitionSet.hpp @@ -20,6 +20,7 @@ class TransitionSet Transition * getBestAppliableTransition(const Config & c); std::size_t getTransitionIndex(const Transition * transition) const; Transition * getTransition(std::size_t index); + Transition * getTransition(const std::string & name); std::size_t size() const; }; diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index c15d638..427ec98 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -203,7 +203,7 @@ Action Action::setRoot() { int rootIndex = -1; - for (int i = config.getWordIndex()-1; true; --i) + for (int i = config.getStack(0); true; --i) { if (!config.has(0, i, 0)) { @@ -224,7 +224,7 @@ Action Action::setRoot() } } - for (int i = config.getWordIndex()-1; true; --i) + for (int i = config.getStack(0); true; --i) { if (!config.has(0, i, 0)) { @@ -276,7 +276,7 @@ Action Action::updateIds() auto apply = [](Config & config, Action & a) { int firstIndexOfSentence = -1; - for (int i = config.getWordIndex()-1; true; --i) + for (int i = config.getStack(0); true; --i) { if (!config.has(0, i, 0)) { @@ -296,7 +296,7 @@ Action Action::updateIds() if (firstIndexOfSentence < 0) util::myThrow("could not find any token in current sentence"); - for (unsigned int i = firstIndexOfSentence, currentId = 1; i < config.getWordIndex(); ++i) + for (unsigned int i = firstIndexOfSentence, currentId = 1; i <= config.getStack(0); ++i) { if (!config.isToken(i)) continue; diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 58426e4..49ad509 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -428,7 +428,7 @@ std::size_t Config::getStack(int relativeIndex) const bool Config::hasHistory(int relativeIndex) const { - return relativeIndex > 0 && relativeIndex < (int)history.size(); + return relativeIndex >= 0 && relativeIndex < (int)history.size(); } bool Config::hasStack(int relativeIndex) const @@ -451,7 +451,7 @@ bool Config::stateIsDone() const if (!rawInput.empty()) return rawInputOnlySeparatorsLeft(); - return !has(0, wordIndex+1, 0); + return !has(0, wordIndex+1, 0) and !hasStack(0); } std::vector<long> Config::extractContext(int leftBorder, int rightBorder, Dict & dict) const diff --git a/reading_machine/src/Strategy.cpp b/reading_machine/src/Strategy.cpp index 14853d3..ee1efa6 100644 --- a/reading_machine/src/Strategy.cpp +++ b/reading_machine/src/Strategy.cpp @@ -82,7 +82,7 @@ std::pair<std::string, int> Strategy::getMovementSequential(const Config & c, co util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition)); if (!c.stateIsDone()) - return {c.getState(), movement}; + return {c.getState(), c.canMoveWordIndex(movement) ? movement : 0}; if (!isDone[target]) return {target, -c.getWordIndex()}; diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp index e7878ac..d5b9716 100644 --- a/reading_machine/src/TransitionSet.cpp +++ b/reading_machine/src/TransitionSet.cpp @@ -85,3 +85,13 @@ Transition * TransitionSet::getTransition(std::size_t index) return &transitions[index]; } +Transition * TransitionSet::getTransition(const std::string & name) +{ + for (auto & transition : transitions) + if (transition.getName() == name) + return &transition; + + return nullptr; +} + + diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 595c70b..195cc5c 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -38,6 +38,8 @@ void Trainer::createDataset(SubConfig & config, bool debug) config.addToHistory(transition->getName()); auto movement = machine.getStrategy().getMovement(config, transition->getName()); + if (debug) + fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second); if (movement == Strategy::endMovement) break; -- GitLab