From eaeb5040d8f044a825e4368359c1570dcfd82e1f Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 18 Sep 2019 18:04:53 +0200
Subject: [PATCH] Added tokenizer for training

---
 maca_common/include/util.hpp              |   2 +
 maca_common/src/util.cpp                  |  30 ++++++
 trainer/src/TrainInfos.cpp                |   4 +
 trainer/src/Trainer.cpp                   |  25 +----
 transition_machine/include/ActionBank.hpp |  18 ++++
 transition_machine/include/Config.hpp     |  12 +++
 transition_machine/src/ActionBank.cpp     | 106 ++++++++++++++++++++++
 transition_machine/src/Config.cpp         |  76 ++++++++++++++--
 transition_machine/src/FeatureBank.cpp    |  40 ++++++++
 transition_machine/src/Oracle.cpp         |  71 ++++++++++++++-
 10 files changed, 354 insertions(+), 30 deletions(-)

diff --git a/maca_common/include/util.hpp b/maca_common/include/util.hpp
index 6770d97..7833af1 100644
--- a/maca_common/include/util.hpp
+++ b/maca_common/include/util.hpp
@@ -204,7 +204,9 @@ float getRandomValueInRange(int range);
 int getNbLines(const std::string & filename);
 
 int getStartIndexOfNthSymbol(const std::string & s, int n);
+int getStartIndexOfNthSymbolFrom(const std::string::iterator & s, const std::string::iterator & end, int n);
 int getEndIndexOfNthSymbol(const std::string & s, int n);
+int getEndIndexOfNthSymbolFrom(const std::string::iterator & s, const std::string::iterator & end, int n);
 unsigned int getNbSymbols(const std::string & s);
 std::string shrinkString(const std::string & base, int maxSize, const std::string token);
 
diff --git a/maca_common/src/util.cpp b/maca_common/src/util.cpp
index ea26bd6..8682941 100644
--- a/maca_common/src/util.cpp
+++ b/maca_common/src/util.cpp
@@ -451,6 +451,26 @@ int getStartIndexOfNthSymbol(const std::string & s, int n)
   return it - s.begin();
 }
 
+int getStartIndexOfNthSymbolFrom(const std::string::iterator & s, const std::string::iterator & end, int n)
+{
+  // Byte offset from s to the symbol |n| symbols after (n >= 0) or before (n < 0) it.
+  // `end` bounds the walk in either direction, so callers pass the lower bound when n < 0.
+  auto cursor = s;
+  const bool forward = n >= 0;
+  try
+  {
+    for (int remaining = forward ? n : -n; remaining > 0; remaining--)
+      if (forward)
+        utf8::next(cursor, end);
+      else
+        utf8::prior(cursor, end);
+  }
+  // Running out of room yields a sentinel whose sign no valid offset in that direction has.
+  catch (utf8::not_enough_room &) {return forward ? -1 : 1;}
+
+  return cursor - s;
+}
+
 int getEndIndexOfNthSymbol(const std::string & s, int n)
 {
   auto it = s.begin();
@@ -461,6 +481,16 @@ int getEndIndexOfNthSymbol(const std::string & s, int n)
   return (it-1) - s.begin();
 }
 
+int getEndIndexOfNthSymbolFrom(const std::string::iterator & s, const std::string::iterator & end, int n)
+{
+  // Consumes symbols 0..n from s. Out of room before the last one: -1; on the last one: index of the last byte before end.
+  auto cursor = s;
+  for (int consumed = 0; consumed <= n; consumed++)
+    try {utf8::next(cursor, end);}
+    catch (utf8::not_enough_room &) {if (consumed != n) return -1; return end - s - 1;}
+  return (cursor-1) - s;
+}
+
 unsigned int getNbSymbols(const std::string & s)
 {
   return utf8::distance(s.begin(), s.end());
diff --git a/trainer/src/TrainInfos.cpp b/trainer/src/TrainInfos.cpp
index aa84078..8677127 100644
--- a/trainer/src/TrainInfos.cpp
+++ b/trainer/src/TrainInfos.cpp
@@ -161,6 +161,8 @@ void TrainInfos::computeTrainScores(Config & c)
       addTrainScore(it.first, computeScoreOnTapes(c, {"GOV", "LABEL"}, 0, c.getHead()));
     else if (it.first == "Tagger")
       addTrainScore(it.first, computeScoreOnTapes(c, {"POS"}, 0, c.getHead()));
+    else if (it.first == "Tokenizer")
+      addTrainScore(it.first, computeScoreOnTapes(c, {"FORM"}, 0, c.getHead()));
     else if (it.first == "Morpho")
       addTrainScore(it.first, computeScoreOnTapes(c, {"MORPHO"}, 0, c.getHead()));
     else if (it.first == "Lemmatizer_Rules")
@@ -183,6 +185,8 @@ void TrainInfos::computeDevScores(Config & c)
       addDevScore(it.first, computeScoreOnTapes(c, {"GOV", "LABEL"}, 0, c.getHead()));
     else if (it.first == "Parser")
       addDevScore(it.first, computeScoreOnTapes(c, {"GOV", "LABEL"}, 0, c.getHead()));
+    else if (it.first == "Tokenizer")
+      addDevScore(it.first, computeScoreOnTapes(c, {"FORM"}, 0, c.getHead()));
     else if (it.first == "Tagger")
       addDevScore(it.first, computeScoreOnTapes(c, {"POS"}, 0, c.getHead()));
     else if (it.first == "Morpho")
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 56a2a49..988838f 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -239,7 +239,6 @@ void Trainer::doStepTrain()
 
   std::string pAction = "";
   std::string oAction = "";
-  bool pActionIsZeroCost = false;
 
   std::string actionName = "";
   float loss = 0.0;
@@ -253,19 +252,10 @@ void Trainer::doStepTrain()
   
       for (auto & it : weightedActions)
         if (it.first)
-        {
           if (pAction == "")
             pAction = it.second.second;
   
-          if (tm.getCurrentClassifier()->getActionCost(trainConfig, it.second.second) == 0)
-          {
-            oAction = it.second.second;
-            break;
-          }
-        }
-  
-      if (pAction == oAction)
-        pActionIsZeroCost = true;
+      oAction = tm.getCurrentClassifier()->getZeroCostActions(trainConfig)[0];
     }
     else
     {
@@ -311,17 +301,8 @@ void Trainer::doStepTrain()
     }
     else
     {
-      if (pActionIsZeroCost)
-      {
-        actionName = pAction;
-        TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = true;
-      }
-      else
-      {
-        actionName = oAction;
-        TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = false;
-      }
-
+      actionName = oAction;
+      TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = oAction == pAction;
     }
   
     if (ProgramParameters::debug)
diff --git a/transition_machine/include/ActionBank.hpp b/transition_machine/include/ActionBank.hpp
index bb44dc6..8fe5817 100644
--- a/transition_machine/include/ActionBank.hpp
+++ b/transition_machine/include/ActionBank.hpp
@@ -80,6 +80,10 @@ class ActionBank
   /// @param relativeIndex The index of the column that will be read and written into, relatively to the head of the Config.
   static void writeRuleResult(Config & config, const std::string & fromTapeName, const std::string & targetTapeName, const std::string & rule, int relativeIndex);
 
+  static void addCharToBuffer(Config & config, const std::string & tapeName, int relativeIndex);
+
+  static void removeCharFromBuffer(Config & config, const std::string & tapeName, int relativeIndex);
+
   /// \brief Write something on the buffer
   ///
   /// \param tapeName The tape we will write to
@@ -96,6 +100,20 @@ class ActionBank
   /// \return A BasicAction moving the head
   static Action::BasicAction moveHead(int movement);
 
+  /// \brief Move the raw input head
+  ///
+  /// \param movement The relative movement of the raw input head
+  ///
+  /// \return A BasicAction moving the head
+  static Action::BasicAction moveRawInputHead(int movement);
+
+  /// \brief Verify if rawInput begins with word
+  ///
+  /// \param word The word to verify
+  ///
+  /// \return A BasicAction only appliable if word is the prefix of rawInput.
+  static Action::BasicAction rawInputBeginsWith(std::string word);
+
   /// \brief Write something on the buffer
   ///
   /// \param tapeName The tape we will write to
diff --git a/transition_machine/include/Config.hpp b/transition_machine/include/Config.hpp
index fc7eefc..9d16f90 100644
--- a/transition_machine/include/Config.hpp
+++ b/transition_machine/include/Config.hpp
@@ -181,6 +181,12 @@ class Config
   LimitedStack< std::pair<std::string, Action> > pastActions;
   /// @brief The last action that have been undone.
   std::pair<std::string, int> lastUndoneAction;
+  /// @brief The input before tokenization.
+  std::string rawInput;
+  /// @brief Head of the raw input.
+  int rawInputHead;
+  /// @brief Index of the rawInputHead in terms of bytes.
+  int rawInputHeadIndex;
 
   public :
 
@@ -231,6 +237,10 @@ class Config
   ///
   /// @param mvt The relative increment in the position of the head.
   void moveHead(int mvt);
+  /// @brief Move the rawInputHead of this Config.
+  ///
+  /// @param mvt The relative increment in the position of the rawInputHead.
+  void moveRawInputHead(int mvt);
   /// @brief Whether or not this Config is terminal.
   ///
   /// A Config is terminal when the head is at the end of the multi-tapes buffer and the stack is empty.
@@ -340,6 +350,8 @@ class Config
   ///
   /// @return True if the head is at the end of the tapes.
   bool endOfTapes() const;
+  /// @brief Update rawInput according to the tape TEXT.
+  void updateRawInput();
   /// @brief Set the output file.
   void setOutputFile(FILE * outputFile);
   /// @brief Print the cells that have not been printed.
diff --git a/transition_machine/src/ActionBank.cpp b/transition_machine/src/ActionBank.cpp
index 8d886a8..22bdde4 100644
--- a/transition_machine/src/ActionBank.cpp
+++ b/transition_machine/src/ActionBank.cpp
@@ -17,6 +17,43 @@ Action::BasicAction ActionBank::moveHead(int movement)
   return basicAction;
 }
 
+Action::BasicAction ActionBank::moveRawInputHead(int movement)
+{
+  // Builds the action moving the raw input head by movement; undoing it moves the head back.
+  auto forward = [movement](Config & config, Action::BasicAction &)
+    {config.moveRawInputHead(movement);};
+  auto backward = [movement](Config & config, Action::BasicAction &)
+    {config.moveRawInputHead(-movement);};
+  // NOTE(review): this guard adds movement to a byte index; confirm it holds for multi-byte symbols.
+  auto canMove = [movement](Config & config, Action::BasicAction &)
+    {return (int)config.rawInput.size() > config.rawInputHeadIndex+movement;};
+  return Action::BasicAction
+    {Action::BasicAction::Type::MoveHead, "", forward, backward, canMove};
+}
+
+Action::BasicAction ActionBank::rawInputBeginsWith(std::string word)
+{
+  // Builds a guard action: its apply and undo steps do nothing, only the
+  // appliable check carries meaning.
+  auto noOp = [](Config &, Action::BasicAction &)
+    {};
+
+  // Appliable only when word matches the raw input byte for byte at the head
+  // and at least one byte of raw input remains after the match.
+  auto matchesAtHead = [word](Config & c, Action::BasicAction &)
+    {
+      const auto headIndex = c.rawInputHeadIndex;
+
+      if (headIndex+word.size() >= c.rawInput.size())
+        return false;
+
+      return c.rawInput.compare(headIndex, word.size(), word) == 0;
+    };
+
+  return Action::BasicAction
+    {Action::BasicAction::Type::Write, "", noOp, noOp, matchesAtHead};
+}
+
 Action::BasicAction ActionBank::bufferWrite(std::string tapeName, std::string value, int relativeIndex)
 {
   auto apply = [tapeName, value, relativeIndex](Config & c, Action::BasicAction &)
@@ -192,6 +229,50 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na
 
     sequence.emplace_back(moveHead(movement));
   }
+  else if(std::string(b1) == "IGNORECHAR")
+  {
+    sequence.emplace_back(moveRawInputHead(1));
+  }
+  else if(std::string(b1) == "ENDWORD")
+  {
+  }
+  else if(std::string(b1) == "ADDCHARTOWORD")
+  {
+    auto apply = [](Config & c, Action::BasicAction &)
+      {addCharToBuffer(c, "FORM", 0);};
+    auto undo = [](Config & c, Action::BasicAction &)
+      {removeCharFromBuffer(c, "FORM", 0);};
+    auto appliable = [](Config & , Action::BasicAction &)
+      {return true;};
+    Action::BasicAction basicAction =
+      {Action::BasicAction::Type::Write, "", apply, undo, appliable};
+
+    sequence.emplace_back(basicAction);
+    sequence.emplace_back(moveRawInputHead(1));
+  }
+  else if(std::string(b1) == "SPLITWORD")
+  {
+    if (sscanf(name.c_str(), "SPLITWORD %s", b2) != 1)
+      invalidNameAndAbort(ERRINFO);
+
+    auto splited = split(b2, '@');
+    int nbSymbols = getNbSymbols(splited[0]);
+
+    sequence.emplace_back(rawInputBeginsWith(splited[0]));
+
+    sequence.emplace_back(moveRawInputHead(nbSymbols));
+
+    for (unsigned int i = 1; i < splited.size(); i++)
+      sequence.emplace_back(bufferWrite("FORM", splited[i], i-1));
+  }
+  else if(std::string(b1) == "MOVERAW")
+  {
+    int movement;
+    if (sscanf(name.c_str(), "MOVERAW %d", &movement) != 1)
+      invalidNameAndAbort(ERRINFO);
+
+    sequence.emplace_back(moveRawInputHead(movement));
+  }
   else if(std::string(b1) == "ERROR")
   {
     auto apply = [](Config &, Action::BasicAction &)
@@ -675,6 +756,31 @@ void ActionBank::writeRuleResult(Config & config, const std::string & fromTapeNa
   toTape.setHyp(relativeIndex, applyRule(from, rule));
 }
 
+void ActionBank::addCharToBuffer(Config & config, const std::string & tapeName, int relativeIndex)
+{
+  // Appends the UTF-8 symbol under the raw input head to the hypothesis cell at relativeIndex.
+  auto & targetTape = config.getTape(tapeName);
+  const auto & current = targetTape.getHyp(relativeIndex);
+
+  const auto symbolBegin = config.rawInput.begin()+config.rawInputHeadIndex;
+  const int symbolBytes = getEndIndexOfNthSymbolFrom(symbolBegin, config.rawInput.end(), 0)+1;
+  const std::string symbol(symbolBegin, symbolBegin+symbolBytes);
+  targetTape.setHyp(relativeIndex, current+symbol);
+}
+
+void ActionBank::removeCharFromBuffer(Config & config, const std::string & tapeName, int relativeIndex)
+{
+  // Inverse of addCharToBuffer: strips from the hypothesis cell the symbol under the
+  // raw input head (assumes the head is back on the symbol that was appended - TODO confirm).
+  auto & tape = config.getTape(tapeName);
+  std::string from = tape.getHyp(relativeIndex);
+  // The end index is inclusive, hence +1 to obtain the symbol length in bytes.
+  int nbChar = getEndIndexOfNthSymbolFrom(config.rawInput.begin()+config.rawInputHeadIndex,config.rawInput.end(), 0)+1;
+  for (int i = 0; i < nbChar && !from.empty(); i++)
+    from.pop_back();
+  tape.setHyp(relativeIndex, from);
+}
+
 int ActionBank::getLinkLength(const Config & c, const std::string & action)
 {
   auto splitted = split(action, ' ');
diff --git a/transition_machine/src/Config.cpp b/transition_machine/src/Config.cpp
index 10fe33b..69e37ea 100644
--- a/transition_machine/src/Config.cpp
+++ b/transition_machine/src/Config.cpp
@@ -10,9 +10,11 @@ Config::Config(BD & bd, const std::string inputFilename) : bd(bd), hashHistory(H
 {
   this->outputFile = nullptr;
   this->stackHistory = -1;
-  this->inputFilename = inputFilename;
   this->lastIndexPrinted = -1;
+  this->inputFilename = inputFilename;
   head = 0;
+  rawInputHead = 0;
+  rawInputHeadIndex = 0;
   inputAllRead = false;
   for(int i = 0; i < bd.getNbLines(); i++)
     tapes.emplace_back(bd.getNameOfLine(i), bd.lineIsKnown(i));
@@ -31,6 +33,9 @@ Config::Config(const Config & other) : bd(other.bd), hashHistory(other.hashHisto
   this->lastIndexPrinted = other.lastIndexPrinted;
   this->tapes = other.tapes;
   this->totalEntropy = other.totalEntropy;
+  this->rawInputHead = other.rawInputHead;
+  this->rawInputHeadIndex = other.rawInputHeadIndex;
+  this->rawInput = other.rawInput;
 
   this->inputFilename = other.inputFilename;
   this->inputAllRead = other.inputAllRead;
@@ -137,6 +142,9 @@ void Config::readInput()
       tape.addToHyp("");
     }
   }
+
+  if (hasTape("TEXT"))
+    updateRawInput();
 }
 
 void Config::printForDebug(FILE * output)
@@ -174,6 +182,21 @@ void Config::printForDebug(FILE * output)
   for(int i = 0; i < 80; i++)
     fprintf(output, "-%s", i == 80-1 ? "\n" : "");
 
+  if (!rawInput.empty())
+  {
+    // Show up to rawWindow+1 symbols of raw input, starting at the raw input head.
+    int rawWindow = 30;
+    int relativeHeadIndex = getEndIndexOfNthSymbolFrom(rawInput.begin()+rawInputHeadIndex, rawInput.end(), rawWindow);
+    // A negative index means fewer symbols remain, so print through the end of rawInput.
+    auto endIter = relativeHeadIndex < 0 ? rawInput.end() : rawInput.begin() + rawInputHeadIndex + relativeHeadIndex + 1;
+
+    std::string toPrint(rawInput.begin()+rawInputHeadIndex, endIter);
+    fprintf(output, "%s\n", toPrint.c_str());
+
+    for(int i = 0; i < 80; i++)
+      fprintf(output, "-%s", i == 80-1 ? "\n" : "");
+  }
+
   printColumns(output, cols, 3);
 
   fprintf(output, "Stack : ");
@@ -227,6 +250,28 @@ void Config::moveHead(int mvt)
   }
 }
 
+void Config::moveRawInputHead(int mvt)
+{
+  // Moves the raw input head by mvt symbols, keeping the symbol count
+  // (rawInputHead) and the byte offset (rawInputHeadIndex) in step.
+  // A movement that does not fit inside rawInput, or a null one,
+  // leaves both untouched.
+  const bool forward = mvt >= 0;
+
+  // Walking backward is bounded by the beginning of rawInput.
+  const auto bound = forward ? rawInput.end() : rawInput.begin();
+  const int byteOffset = getStartIndexOfNthSymbolFrom(rawInput.begin()+rawInputHeadIndex, bound, mvt);
+
+  // A successful move has a byte offset of the same sign as mvt.
+  const bool moved = forward ? byteOffset > 0 : byteOffset < 0;
+
+  if (!moved)
+    return;
+
+  rawInputHead += mvt;
+  rawInputHeadIndex += byteOffset;
+}
+
 bool Config::isFinal()
 {
   return endOfTapes() && stack.empty();
@@ -248,6 +293,8 @@ void Config::reset()
 
   inputAllRead = false;
   head = 0;
+  rawInputHead = 0;
+  rawInputHeadIndex = 0;
 
   file.reset();
   while (tapes[0].size() < ProgramParameters::readSize*4 && !inputAllRead)
@@ -327,16 +374,17 @@ LimitedStack<float> & Config::getCurrentStateEntropyHistory()
 
 void Config::shuffle(const std::string & delimiterTape, const std::string & delimiter)
 {
-  std::vector< std::pair<unsigned int, unsigned int> > delimiters;
+  struct Trio{unsigned int a; unsigned int b; unsigned int c; Trio(unsigned int a, unsigned int b, unsigned int c): a(a), b(b), c(c){}};
+  std::vector<Trio> delimiters;
 
   if (delimiterTape == "0")
   {
     unsigned int previousIndex = 0;
     for (int i = 0; i < tapes[0].refSize(); i++)
     {
-      delimiters.emplace_back(previousIndex, i);
+      delimiters.emplace_back(previousIndex, i, delimiters.size());
       previousIndex = i+1;
-    }   
+    }
   }
   else
   {
@@ -345,7 +393,7 @@ void Config::shuffle(const std::string & delimiterTape, const std::string & deli
     for (int i = 0; i < tape.refSize(); i++)
       if (tape.getRef(i-head) == delimiter)
       {
-        delimiters.emplace_back(previousIndex, i);
+        delimiters.emplace_back(previousIndex, i, delimiters.size());
         previousIndex = i+1;
       }
   }
@@ -356,7 +404,7 @@ void Config::shuffle(const std::string & delimiterTape, const std::string & deli
     return;
   }
 
-  std::pair<unsigned int, unsigned int> suffix = {delimiters.back().second+1, tapes[0].refSize()-1};
+  std::pair<unsigned int, unsigned int> suffix = {delimiters.back().b+1, tapes[0].refSize()-1};
 
   std::random_shuffle(delimiters.begin(), delimiters.end());
 
@@ -367,13 +415,16 @@ void Config::shuffle(const std::string & delimiterTape, const std::string & deli
     newTapes[tape].clearDataForCopy();
 
     for (auto & delimiter : delimiters)
-      newTapes[tape].copyPart(tapes[tape], delimiter.first, delimiter.second+1);
+      newTapes[tape].copyPart(tapes[tape], delimiter.a, delimiter.b+1);
 
     if (suffix.first <= suffix.second)
       newTapes[tape].copyPart(tapes[tape], suffix.first, suffix.second+1);
   }
 
   tapes = newTapes;
+
+  if (!rawInput.empty())
+    updateRawInput();
 }
 
 int Config::stackGetElem(int index) const
@@ -658,3 +709,14 @@ float Config::Tape::getScore(int from, int to)
   return 100.0*res / (1+to-from);
 }
 
+void Config::updateRawInput()
+{
+  // Rebuilds rawInput from the TEXT tape: every cell different from "_" is
+  // appended, with a single space between consecutive appended cells.
+  rawInput.clear();
+  auto & textTape = getTape("TEXT");
+  for (int i = 0; i < textTape.size(); i++)
+    if (textTape[i] != "_")
+      rawInput += (rawInput.empty() ? "" : " ") + textTape[i];
+}
+
diff --git a/transition_machine/src/FeatureBank.cpp b/transition_machine/src/FeatureBank.cpp
index ab40482..ccef787 100644
--- a/transition_machine/src/FeatureBank.cpp
+++ b/transition_machine/src/FeatureBank.cpp
@@ -212,6 +212,46 @@ FeatureModel::FeatureValue getDistance(int index1, int index2, const std::string
 
 std::function<FeatureModel::FeatureValue(Config &)> FeatureBank::str2func(const std::string & s)
 {
+  if (split(s,'.')[0] == "raw")
+  {
+    int relativeIndex;
+    try {relativeIndex = std::stoi(split(s, '.')[1]);}
+    catch (std::exception &)
+    {
+      fprintf(stderr, "ERROR (%s) : invalid feature format \'%s\'. Relative index must be an integer. Aborting.\n", ERRINFO, s.c_str());
+      exit(1);
+    }
+      return [relativeIndex, s](Config & c)
+      {
+        int relativeCharIndex = getStartIndexOfNthSymbolFrom(c.rawInput.begin()+c.rawInputHeadIndex, relativeIndex >= 0 ? c.rawInput.end() : c.rawInput.begin(), relativeIndex);
+
+        Dict * dict = Dict::getDict("letters");
+        auto policy = dictPolicy2FeaturePolicy(dict->policy);
+
+        if (relativeCharIndex >= 0 && relativeIndex < 0)
+          return FeatureModel::FeatureValue({dict, s, Dict::nullValueStr, policy});
+        if (relativeCharIndex < 0 && relativeIndex >= 0)
+          return FeatureModel::FeatureValue({dict, s, Dict::nullValueStr, policy});
+
+        int endIndex = getEndIndexOfNthSymbolFrom(c.rawInput.begin()+c.rawInputHeadIndex+relativeCharIndex, c.rawInput.end(), 0);
+
+        auto a = c.rawInput.begin()+c.rawInputHeadIndex+relativeCharIndex;
+        auto b = a + endIndex + 1;
+
+        std::string value;
+
+        if (a <= b)
+          value = std::string(a,b);
+        else
+          value = std::string(b,a);
+
+        if (value.empty())
+          return FeatureModel::FeatureValue({dict, s, Dict::nullValueStr, policy});
+
+        return FeatureModel::FeatureValue({dict, s, value, policy});
+      };
+  }
+
   auto splited = split(s, '#');
 
   if (splited.size() == 1)
diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp
index aeaf7e0..769684e 100644
--- a/transition_machine/src/Oracle.cpp
+++ b/transition_machine/src/Oracle.cpp
@@ -211,7 +211,40 @@ void Oracle::createDatabase()
   },
   [](Config & c, Oracle *, const std::string & action)
   {
-    return (action == "WRITE b.0 BIO " + c.getTape("BIO").getRef(0) || c.endOfTapes()) ? 0 : 1;
+    // Returns 0 when the action is consistent with the reference FORM tape and the raw input, 1 otherwise.
+    auto & currentWordRef = c.getTape("FORM").getRef(0);
+    auto & currentWordHyp = c.getTape("FORM").getHyp(0);
+
+    auto splited = split(split(action, ' ').back(),'@');
+
+    // More than two '@'-separated parts: a raw prefix followed by the words it is split into.
+    if (splited.size() > 2)
+    {
+      // Bound check relative to the raw input head, because the comparison below starts there.
+      if (c.rawInputHeadIndex + splited[0].size() >= c.rawInput.size())
+        return 1;
+
+      if (c.rawInput.compare(c.rawInputHeadIndex, splited[0].size(), splited[0]) != 0)
+        return 1;
+
+      for (unsigned int i = 1; i < splited.size(); i++)
+        if (c.getTape("FORM").getRef(i-1) != splited[i])
+          return 1;
+
+      return 0;
+    }
+
+    if (currentWordRef == currentWordHyp && action == "ENDWORD")
+      return 0;
+
+    if (action == "ADDCHARTOWORD" && currentWordRef.size() > currentWordHyp.size())
+    {
+      // compare() clamps at the end of rawInput, so a too-short raw input is a mismatch, not an overrun.
+      auto missing = currentWordRef.size()-currentWordHyp.size();
+      return c.rawInput.compare(c.rawInputHeadIndex, missing, currentWordRef, currentWordHyp.size(), missing) == 0 ? 0 : 1;
+    }
+
+    return 1;
   })));
 
   str2oracle.emplace("eos", std::unique_ptr<Oracle>(new Oracle(
@@ -288,6 +321,42 @@ void Oracle::createDatabase()
     return 0;
   })));
 
+  str2oracle.emplace("strategy_tokenizer,tagger", std::unique_ptr<Oracle>(new Oracle(
+  [](Oracle *)
+  {
+  },
+  [](Config & c, Oracle *)
+  {
+    if (c.pastActions.size() == 0)
+      return std::string("MOVE tokenizer 0");
+
+    std::string previousState = noAccentLower(c.pastActions.getElem(0).first);
+    std::string previousAction = noAccentLower(c.pastActions.getElem(0).second.name);
+    std::string newState;
+    int movement = 0;
+
+    if (previousState == "signature")
+      newState = "tagger";
+    else if (previousState == "tokenizer")
+    {
+      if (split(previousAction, ' ')[0] == "splitword" || split(previousAction, ' ')[0] == "endword")
+        newState = "signature";
+      else
+        newState = "tokenizer";
+    }
+    else if (previousState == "tagger" || previousState == "error_tagger")
+    {
+      newState = "tokenizer";
+      movement = 1;
+    }
+
+    return "MOVE " + newState + " " + std::to_string(movement);
+  },
+  [](Config &, Oracle *, const std::string &)
+  {
+    return 0;
+  })));
+
   str2oracle.emplace("strategy_parser", std::unique_ptr<Oracle>(new Oracle(
   [](Oracle *)
   {
-- 
GitLab