From b6fda44aaa039d647259991d47f411f87ce4b1c2 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 16 Feb 2020 23:30:30 +0100
Subject: [PATCH] Forced EOS at the end of decode

---
 decoder/src/Decoder.cpp                   | 12 ++++++++++++
 reading_machine/include/TransitionSet.hpp |  1 +
 reading_machine/src/Action.cpp            |  8 ++++----
 reading_machine/src/Config.cpp            |  4 ++--
 reading_machine/src/Strategy.cpp          |  2 +-
 reading_machine/src/TransitionSet.cpp     | 10 ++++++++++
 trainer/src/Trainer.cpp                   |  2 ++
 7 files changed, 32 insertions(+), 7 deletions(-)

diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index bcb25bd..6f209dd 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -40,6 +40,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
     config.addToHistory(transition->getName());
 
     auto movement = machine.getStrategy().getMovement(config, transition->getName());
+    if (debug)
+      fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
     if (movement == Strategy::endMovement)
       break;
 
@@ -48,6 +50,16 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
       util::myThrow("Cannot move word index !");
   }
   } catch(std::exception & e) {util::myThrow(e.what());}
+
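+  // Force EOS when an "EOS" transition exists and the last non-empty hypothesis of the EOS column at the current word index is not already EOSSymbol1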
+  if (machine.getTransitionSet().getTransition("EOS") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1)
+  {
+    Action shift = Action::pushWordIndexOnStack();
+    shift.apply(config, shift);
+    machine.getTransitionSet().getTransition("EOS")->apply(config);
+    if (debug)
+      fmt::print(stderr, "Forcing EOS transition\n");
+  }
 }
 
 float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
diff --git a/reading_machine/include/TransitionSet.hpp b/reading_machine/include/TransitionSet.hpp
index df9551c..4263ba4 100644
--- a/reading_machine/include/TransitionSet.hpp
+++ b/reading_machine/include/TransitionSet.hpp
@@ -20,6 +20,7 @@ class TransitionSet
   Transition * getBestAppliableTransition(const Config & c);
   std::size_t getTransitionIndex(const Transition * transition) const;
   Transition * getTransition(std::size_t index);
+  Transition * getTransition(const std::string & name);
   std::size_t size() const;
 };
 
diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp
index c15d638..427ec98 100644
--- a/reading_machine/src/Action.cpp
+++ b/reading_machine/src/Action.cpp
@@ -203,7 +203,7 @@ Action Action::setRoot()
   {
     int rootIndex = -1;
 
-    for (int i = config.getWordIndex()-1; true; --i)
+    for (int i = config.getStack(0); true; --i)
     {
       if (!config.has(0, i, 0))
       {
@@ -224,7 +224,7 @@ Action Action::setRoot()
       }
     }
 
-    for (int i = config.getWordIndex()-1; true; --i)
+    for (int i = config.getStack(0); true; --i)
     {
       if (!config.has(0, i, 0))
       {
@@ -276,7 +276,7 @@ Action Action::updateIds()
   auto apply = [](Config & config, Action & a)
   {
     int firstIndexOfSentence = -1;
-    for (int i = config.getWordIndex()-1; true; --i)
+    for (int i = config.getStack(0); true; --i)
     {
       if (!config.has(0, i, 0))
       {
@@ -296,7 +296,7 @@ Action Action::updateIds()
     if (firstIndexOfSentence < 0)
       util::myThrow("could not find any token in current sentence");
 
-    for (unsigned int i = firstIndexOfSentence, currentId = 1; i < config.getWordIndex(); ++i)
+    for (unsigned int i = firstIndexOfSentence, currentId = 1; i <= config.getStack(0); ++i)
     {
       if (!config.isToken(i))
         continue;
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 58426e4..49ad509 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -428,7 +428,7 @@ std::size_t Config::getStack(int relativeIndex) const
 
 bool Config::hasHistory(int relativeIndex) const
 {
-  return relativeIndex > 0 && relativeIndex < (int)history.size();
+  return relativeIndex >= 0 && relativeIndex < (int)history.size();
 }
 
 bool Config::hasStack(int relativeIndex) const
@@ -451,7 +451,7 @@ bool Config::stateIsDone() const
   if (!rawInput.empty())
     return rawInputOnlySeparatorsLeft();
 
-  return !has(0, wordIndex+1, 0);
+  return !has(0, wordIndex+1, 0) and !hasStack(0);
 }
 
 std::vector<long> Config::extractContext(int leftBorder, int rightBorder, Dict & dict) const
diff --git a/reading_machine/src/Strategy.cpp b/reading_machine/src/Strategy.cpp
index 14853d3..ee1efa6 100644
--- a/reading_machine/src/Strategy.cpp
+++ b/reading_machine/src/Strategy.cpp
@@ -82,7 +82,7 @@ std::pair<std::string, int> Strategy::getMovementSequential(const Config & c, co
     util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition));
 
   if (!c.stateIsDone())
-    return {c.getState(), movement};
+    return {c.getState(), c.canMoveWordIndex(movement) ? movement : 0};
 
   if (!isDone[target])
     return {target, -c.getWordIndex()};
diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp
index e7878ac..d5b9716 100644
--- a/reading_machine/src/TransitionSet.cpp
+++ b/reading_machine/src/TransitionSet.cpp
@@ -85,3 +85,13 @@ Transition * TransitionSet::getTransition(std::size_t index)
   return &transitions[index];
 }
 
+Transition * TransitionSet::getTransition(const std::string & name)
+{
+  // Look a transition up by name: linear scan, first match wins.
+  for (std::size_t i = 0; i < transitions.size(); ++i)
+    if (transitions[i].getName() == name)
+      return &transitions[i];
+  return nullptr; // no transition carries that name
+}
+
+
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 595c70b..195cc5c 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -38,6 +38,8 @@ void Trainer::createDataset(SubConfig & config, bool debug)
     config.addToHistory(transition->getName());
 
     auto movement = machine.getStrategy().getMovement(config, transition->getName());
+    if (debug)
+      fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
     if (movement == Strategy::endMovement)
       break;
 
-- 
GitLab