From 3e58bb021cb3f8be67f53abebb47795f2dea7471 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 25 Nov 2019 01:13:39 +0100
Subject: [PATCH] EOS is now predicted by an other classifier named segmenter

---
 transition_machine/src/Config.cpp |   4 +
 transition_machine/src/Oracle.cpp | 118 +++++++++++++++++++++---------
 2 files changed, 86 insertions(+), 36 deletions(-)

diff --git a/transition_machine/src/Config.cpp b/transition_machine/src/Config.cpp
index 38492f1..90ef227 100644
--- a/transition_machine/src/Config.cpp
+++ b/transition_machine/src/Config.cpp
@@ -941,6 +941,10 @@ void Config::updateIdsInSequence()
   int sentenceEnd = stackHasIndex(0) ? stackGetElem(0) : getHead();
   auto & eos = getTape(ProgramParameters::sequenceDelimiterTape);
   auto & ids = getTape("ID");
+
+  if (getTape("FORM")[sentenceEnd-getHead()].empty())
+    return;
+
   while (sentenceEnd >= 0 && eos.getHyp(sentenceEnd-getHead()) != ProgramParameters::sequenceDelimiter)
     sentenceEnd--;
 
diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp
index 288faa0..0139217 100644
--- a/transition_machine/src/Oracle.cpp
+++ b/transition_machine/src/Oracle.cpp
@@ -526,7 +526,7 @@ void Oracle::createDatabase()
     return 0;
   })));
 
-  str2oracle.emplace("strategy_parser", std::unique_ptr<Oracle>(new Oracle(
+  str2oracle.emplace("strategy_parser_legacy", std::unique_ptr<Oracle>(new Oracle(
   [](Oracle *)
   {
   },
@@ -567,6 +567,46 @@ void Oracle::createDatabase()
     return 0;
   })));
 
+  str2oracle.emplace("strategy_parser", std::unique_ptr<Oracle>(new Oracle(
+  [](Oracle *)
+  {
+  },
+  [](Config & c, Oracle *)
+  {
+    if (c.pastActions.size() == 0)
+      return std::string("MOVE parser 0");
+
+    std::string previousState = util::noAccentLower(c.pastActions.getElem(0).first);
+    std::string previousAction = util::noAccentLower(c.pastActions.getElem(0).second.name);
+    std::string newState;
+    int movement = 0;
+
+    if (previousState == "parser")
+    {
+      newState = "parser";
+      if (util::split(previousAction, ' ')[0] == "shift" || util::split(previousAction, ' ')[0] == "right")
+        newState = "segmenter";
+    }
+    else if (previousState == "segmenter")
+    {
+      newState = "parser";
+      movement = 1;
+    }
+    else if (previousState == "error_parser")
+    {
+      newState = "parser";
+      std::string previousParserAction = util::noAccentLower(c.pastActions.getElem(1).second.name);
+      if (util::split(previousParserAction, ' ')[0] == "shift" || util::split(previousParserAction, ' ')[0] == "right")
+        movement = 1;
+    }
+
+    return "MOVE " + newState + " " + std::to_string(movement);
+  },
+  [](Config &, Oracle *, const std::string &)
+  {
+    return 0;
+  })));
+
   str2oracle.emplace("strategy_tagger,morpho,lemmatizer,parser", std::unique_ptr<Oracle>(new Oracle(
   [](Oracle *)
   {
@@ -604,14 +644,17 @@ void Oracle::createDatabase()
     {
       if (util::split(previousAction, ' ')[0] == "shift" || util::split(previousAction, ' ')[0] == "right")
       {
-        newState = "tagger";
-        if (c.endOfTapes())
-          newState = "parser";
-        movement = 1;
+        newState = "segmenter";
+        movement = 0;
       }
       else
         newState = "parser";
     }
+    else if (previousState == "segmenter")
+    {
+      newState = "tagger";
+      movement = 1;
+    }
     else
       newState = "unknown("+std::string(ERRINFO)+")("+previousState+")("+previousAction+")";
 
@@ -707,35 +750,36 @@ void Oracle::createDatabase()
     {
       if (util::split(previousAction, ' ')[0] == "shift" || util::split(previousAction, ' ')[0] == "right")
       {
-        newState = "tagger";
-        movement = lastIndexDone[newState]-c.getHead()+1;
+        newState = "segmenter";
+        movement = 0;
         lastIndexDone[previousState] = c.getHead();
+      }
+      else
+        newState = "parser";
+    }
+    else if (previousState == "segmenter")
+    {
+      newState = "tagger";
+      movement = lastIndexDone[newState]-c.getHead()+1;
+      if (lastIndexDone[newState]+1 >= c.getTape("FORM").size())
+      {
+        newState = "morpho";
+        movement = lastIndexDone[newState]-c.getHead()+1;
         if (lastIndexDone[newState]+1 >= c.getTape("FORM").size())
         {
-          newState = "morpho";
-          movement = lastIndexDone[newState]-c.getHead()+1;
-          if (lastIndexDone[newState]+1 >= c.getTape("FORM").size())
+          newState = "lemmatizer_rules";
+          movement = lastIndexDone["lemmatizer_case"]-c.getHead()+1;
+          if (lastIndexDone["lemmatizer_case"]+1 >= c.getTape("FORM").size())
           {
-            newState = "lemmatizer_rules";
-            movement = lastIndexDone["lemmatizer_case"]-c.getHead()+1;
-            if (lastIndexDone["lemmatizer_case"]+1 >= c.getTape("FORM").size())
-            {
-              newState = "parser";
-              movement = lastIndexDone[newState]-c.getHead()+1;
-            }
+            newState = "parser";
+            movement = lastIndexDone[newState]-c.getHead()+1;
           }
         }
-        if (c.endOfTapes())
-        {
-          newState = "parser";
-          movement = 1;
-        }
-        todo["tagger"] = 1;
-        todo["morpho"] = 1;
-        todo["lemmatizer_case"] = 1;
       }
-      else
-        newState = "parser";
+
+      todo["tagger"] = 1;
+      todo["morpho"] = 1;
+      todo["lemmatizer_case"] = 1;
     }
     else
       newState = "unknown("+std::string(ERRINFO)+")("+previousState+")("+previousAction+")";
@@ -807,19 +851,21 @@ void Oracle::createDatabase()
     {
       if (util::split(previousAction, ' ')[0] == "shift" || util::split(previousAction, ' ')[0] == "right")
       {
-        newState = "tokenizer";
-        movement = 1;
-        if (!c.getTape("ID").getHyp(1).empty())
-          newState = "tagger";
-        if (c.endOfTapes() || c.rawInputHeadIndex >= (int)c.rawInput.size())
-        {
-          newState = "parser";
-          movement = 0;
-        }
+        newState = "segmenter";
+        movement = 0;
       }
       else
         newState = "parser";
     }
+    else if (previousState == "segmenter")
+    {
+      newState = "tokenizer";
+      movement = 1;
+      if (!c.getTape("ID").getHyp(1).empty())
+        newState = "tagger";
+      if (c.endOfTapes())
+        movement = 0;
+    }
     else
       newState = "unknown("+std::string(ERRINFO)+")("+previousState+")("+previousAction+")";
 
-- 
GitLab