From d17eb24e79841b460152a95c1a7ef506f7da19dc Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 20 Sep 2019 16:13:27 +0200
Subject: [PATCH] Improved tokeniztion

---
 decoder/src/Decoder.cpp               |  2 +
 trainer/src/Trainer.cpp               | 11 ++++-
 transition_machine/include/Config.hpp |  2 +
 transition_machine/src/ActionBank.cpp | 35 +++++++++++--
 transition_machine/src/Config.cpp     | 13 +++--
 transition_machine/src/Oracle.cpp     | 71 ++++++++++++++++++++++++---
 6 files changed, 113 insertions(+), 21 deletions(-)

diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index f9e5ef1..fec7ae7 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -183,6 +183,8 @@ void computeAndRecordEntropy(Config & config, Classifier::WeightedActions & weig
 
 void applyActionAndTakeTransition(TransitionMachine & tm, const std::string & actionName, Config & config)
 {
+    if (ProgramParameters::debug)
+      fprintf(stderr, "Applying action=<%s>\n", actionName.c_str());
     Action * action = tm.getCurrentClassifier()->getAction(actionName);
     TransitionMachine::Transition * transition = tm.getTransition(actionName);
     action->setInfos(tm.getCurrentClassifier()->name);
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 988838f..dd9fb08 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -111,6 +111,9 @@ void Trainer::computeScoreOnDev()
            }
         }
 
+      if (pAction.empty())
+        break;
+
       if (ProgramParameters::devLoss)
       {
         float loss = tm.getCurrentClassifier()->getLoss(*devConfig, tm.getCurrentClassifier()->getActionIndex(oAction));
@@ -255,11 +258,15 @@ void Trainer::doStepTrain()
           if (pAction == "")
             pAction = it.second.second;
   
-      oAction = tm.getCurrentClassifier()->getZeroCostActions(trainConfig)[0];
+      auto zeroCostActions = tm.getCurrentClassifier()->getZeroCostActions(trainConfig);
+      if (!zeroCostActions.empty())
+        oAction = zeroCostActions[0];
     }
     else
     {
-      oAction = tm.getCurrentClassifier()->getZeroCostActions(trainConfig)[0];
+      auto zeroCostActions = tm.getCurrentClassifier()->getZeroCostActions(trainConfig);
+      if (!zeroCostActions.empty())
+        oAction = zeroCostActions[0];
     }
   
     if (oAction.empty())
diff --git a/transition_machine/include/Config.hpp b/transition_machine/include/Config.hpp
index 9d16f90..f988f2b 100644
--- a/transition_machine/include/Config.hpp
+++ b/transition_machine/include/Config.hpp
@@ -187,6 +187,8 @@ class Config
   int rawInputHead;
   /// @brief Index of the rawInputHead in term of bytes.
   int rawInputHeadIndex;
+  /// @brief Index of current word in the sentence, as in conll format.
+  int currentWordIndex;
 
   public :
 
diff --git a/transition_machine/src/ActionBank.cpp b/transition_machine/src/ActionBank.cpp
index b6c43b6..26d167e 100644
--- a/transition_machine/src/ActionBank.cpp
+++ b/transition_machine/src/ActionBank.cpp
@@ -307,6 +307,17 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na
   {
     sequence.emplace_back(checkNotEmpty("FORM", 0));
     sequence.emplace_back(increaseTapesIfNeeded(1));
+
+    auto apply = [](Config & c, Action::BasicAction &)
+      {simpleBufferWrite(c, "ID", std::to_string(c.currentWordIndex), 0);};
+    auto undo = [](Config & c, Action::BasicAction &)
+      {simpleBufferWrite(c, "ID", std::string(""), 0);};
+    auto appliable = [](Config & c, Action::BasicAction &)
+      {return simpleBufferWriteAppliable(c, "ID", 0);};
+    Action::BasicAction basicAction =
+      {Action::BasicAction::Type::Write, "", apply, undo, appliable};
+
+    sequence.emplace_back(basicAction);
   }
   else if(std::string(b1) == "ADDCHARTOWORD")
   {
@@ -314,8 +325,8 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na
       {addCharToBuffer(c, "FORM", 0);};
     auto undo = [](Config & c, Action::BasicAction &)
       {removeCharFromBuffer(c, "FORM", 0);};
-    auto appliable = [](Config & , Action::BasicAction &)
-      {return true;};
+    auto appliable = [](Config & c, Action::BasicAction &)
+      {return c.getTape("FORM").getHyp(0).size() <= 2000;};
     Action::BasicAction basicAction =
       {Action::BasicAction::Type::Write, "", apply, undo, appliable};
 
@@ -334,10 +345,24 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na
 
     sequence.emplace_back(moveRawInputHead(nbSymbols));
 
-    sequence.emplace_back(increaseTapesIfNeeded(splited.size()-1));
+    sequence.emplace_back(increaseTapesIfNeeded(splited.size()));
 
-    for (unsigned int i = 1; i < splited.size(); i++)
-      sequence.emplace_back(bufferWrite("FORM", splited[i], i-1));
+    for (unsigned int i = 0; i < splited.size(); i++)
+    {
+      sequence.emplace_back(bufferWrite("FORM", splited[i], i));
+
+      int splitedSize = (int)splited.size();
+      auto apply = [i, splitedSize](Config & c, Action::BasicAction &)
+        {simpleBufferWrite(c, "ID", i == 0 ? std::to_string(c.currentWordIndex) + "-" + std::to_string(c.currentWordIndex+splitedSize-2) : std::to_string(c.currentWordIndex+i-1), i);};
+      auto undo = [i](Config & c, Action::BasicAction &)
+        {simpleBufferWrite(c, "ID", std::string(""), i);};
+      auto appliable = [i](Config & c, Action::BasicAction &)
+        {return simpleBufferWriteAppliable(c, "ID", i);};
+      Action::BasicAction basicAction =
+        {Action::BasicAction::Type::Write, "", apply, undo, appliable};
+
+      sequence.emplace_back(basicAction);
+    }
   }
   else if(std::string(b1) == "MOVERAW")
   {
diff --git a/transition_machine/src/Config.cpp b/transition_machine/src/Config.cpp
index 6cc328f..2871956 100644
--- a/transition_machine/src/Config.cpp
+++ b/transition_machine/src/Config.cpp
@@ -14,6 +14,7 @@ Config::Config(BD & bd, const std::string inputFilename) : bd(bd), hashHistory(H
   this->inputFilename = inputFilename;
   head = 0;
   rawInputHead = 0;
+  currentWordIndex = 1;
   rawInputHeadIndex = 0;
   inputAllRead = false;
   for(int i = 0; i < bd.getNbLines(); i++)
@@ -34,6 +35,7 @@ Config::Config(const Config & other) : bd(other.bd), hashHistory(other.hashHisto
   this->tapes = other.tapes;
   this->totalEntropy = other.totalEntropy;
   this->rawInputHead = other.rawInputHead;
+  this->currentWordIndex = other.currentWordIndex;
   this->rawInputHeadIndex = other.rawInputHeadIndex;
   this->rawInput = other.rawInput;
 
@@ -250,17 +252,13 @@ void Config::printAsOutput(FILE * output, int dataIndex, int realIndex)
 
 void Config::moveHead(int mvt)
 {
-//  if (ProgramParameters::rawInput && head + mvt >= tapes[0].size())
-//    for (auto & tape : tapes)
-//    {
-//      tape.addToRef("");
-//      tape.addToHyp("");
-//    }
-
   if (head + mvt < tapes[0].size())
   {
     head += mvt;
 
+    if (hasTape("ID") && split(getTape("ID").getHyp(0), '-').size() <= 1)
+      currentWordIndex += mvt;
+
     for (auto & tape : tapes)
       tape.moveHead(mvt);
 
@@ -322,6 +320,7 @@ void Config::reset()
   head = 0;
   rawInputHead = 0;
   rawInputHeadIndex = 0;
+  currentWordIndex = 1;
 
   file.reset();
   while (tapes[0].size() < ProgramParameters::readSize*4 && !inputAllRead)
diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp
index f5bcc73..a20e397 100644
--- a/transition_machine/src/Oracle.cpp
+++ b/transition_machine/src/Oracle.cpp
@@ -225,8 +225,8 @@ void Oracle::createDatabase()
         if (splited[0][i] != c.rawInput[c.rawInputHeadIndex+i])
           return 1;
 
-      for (unsigned int i = 1; i < splited.size(); i++)
-        if (c.getTape("FORM").getRef(i-1) != splited[i])
+      for (unsigned int i = 0; i < splited.size(); i++)
+        if (c.getTape("FORM").getRef(i) != splited[i])
           return 1;
 
       return 0;
@@ -238,6 +238,9 @@ void Oracle::createDatabase()
 
     if (action == "ADDCHARTOWORD" && currentWordRef.size() > currentWordHyp.size())
     {
+      if (c.hasTape("ID") && split(c.getTape("ID").getRef(0), '-').size() > 1)
+        return 1;
+
       for (unsigned int i = 0; i < (currentWordRef.size()-currentWordHyp.size()); i++)
         if (currentWordRef[currentWordHyp.size()+i] != c.rawInput[c.rawInputHeadIndex+i])
           return 1;
@@ -343,6 +346,9 @@ void Oracle::createDatabase()
         newState = "signature";
       else
         newState = "tokenizer";
+
+      if (split(previousAction, ' ')[0] == "splitword")
+        movement = 1;
     }
     else if (previousState == "tagger" || previousState == "error_tagger")
     {
@@ -585,11 +591,13 @@ void Oracle::createDatabase()
     int head = c.getHead();
     int stackHead = c.stackEmpty() ? 0 : c.stackTop();
     int stackGov = 0;
+    bool stackNoGov = false;
     try {stackGov = stackHead + std::stoi(govs.getRef(stackHead-head));}
-      catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);}
+      catch (std::exception &){stackNoGov = true;}
     int headGov = 0;
+    bool headNoGov = false;
     try {headGov = head + std::stoi(govs.getRef(0));}
-      catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);}
+      catch (std::exception &){headNoGov = true;}
     int sentenceStart = c.getHead()-1 < 0 ? 0 : c.getHead()-1;
     int sentenceEnd = c.getHead();
 
@@ -608,10 +616,14 @@ void Oracle::createDatabase()
 
     if (parts[0] == "SHIFT")
     {
+      if (headNoGov)
+        return 0;
+
       for (int i = sentenceStart; i <= sentenceEnd; i++)
       {
         if (!isNum(govs.getRef(i-head)))
         {
+          continue;
           fprintf(stderr, "ERROR (%s) : govs.ref[%d] = <%s>. Aborting.\n", ERRINFO, i, govs.getRef(i-head).c_str());
           exit(1);
         }
@@ -657,6 +669,8 @@ void Oracle::createDatabase()
     }
     else if (parts[0] == "REDUCE")
     {
+      if (stackNoGov)
+        return 0;
       if (stackGov == 0)
         cost++;
 
@@ -664,7 +678,7 @@ void Oracle::createDatabase()
       {
         int otherGov = 0;
         try {otherGov = i + std::stoi(govs.getRef(i-head));}
-          catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);}
+          catch (std::exception &){continue;}
         if (otherGov == stackHead)
           cost++;
       }
@@ -673,6 +687,9 @@ void Oracle::createDatabase()
     }
     else if (parts[0] == "LEFT")
     {
+      if (stackNoGov)
+        return 0;
+
       if (stackGov == 0)
         cost++;
 
@@ -683,7 +700,7 @@ void Oracle::createDatabase()
       {
         int otherGov = 0;
         try {otherGov = i + std::stoi(govs.getRef(i-head));}
-          catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);}
+          catch (std::exception &){continue;}
         if (otherGov == stackHead || stackGov == i)
           cost++;
       }
@@ -695,6 +712,9 @@ void Oracle::createDatabase()
     }
     else if (parts[0] == "RIGHT")
     {
+      if (headNoGov)
+        return 0;
+
       for (int j = 0; j < c.stackSize(); j++)
       {
         auto s = c.stackGetElem(j);
@@ -704,7 +724,7 @@ void Oracle::createDatabase()
 
         int otherGov = 0;
         try {otherGov = s + std::stoi(govs.getRef(s-head));}
-          catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);}
+          catch (std::exception &){continue;}
         if (otherGov == head || headGov == s)
           cost++;
       }
@@ -810,6 +830,43 @@ void Oracle::explainCostOfAction(FILE * output, Config & c, const std::string &
     fprintf(output, "Wrong write (%s) expected (%s)\n", label.c_str(), expected.c_str());
     return;
   }
+  else if (parts[0] == "IGNORECHAR")
+  {
+    if (!isUtf8Separator(c.rawInput.begin()+c.rawInputHeadIndex))
+    {
+      fprintf(stderr, "rawInputHead is pointing to non separator character <%c>(%d)\n", c.rawInput[c.rawInputHeadIndex], c.rawInput[c.rawInputHeadIndex]);
+      return;
+    }
+    else if (c.rawInputHeadIndex+1 > (int)c.rawInput.size())
+    {
+      fprintf(stderr, "rawInputHeadIndex=%d rawInputSize=%lu\n", c.rawInputHeadIndex, c.rawInput.size());
+      return;
+    }
+
+    fprintf(stderr, "cannot explain\n");
+    return;
+  }
+  else if (parts[0] == "ENDWORD")
+  {
+    if (c.getTape("FORM").getRef(0) != c.getTape("FORM").getHyp(0))
+    {
+      fprintf(stderr, "hyp <%s> and ref <%s> are different\n", c.getTape("FORM").getHyp(0).c_str(), c.getTape("FORM").getRef(0).c_str());
+      return;
+    }
+
+    fprintf(stderr, "cannot explain\n");
+    return;
+  }
+  else if (parts[0] == "ADDCHARTOWORD")
+  {
+    fprintf(stderr, "cannot explain\n");
+    return;
+  }
+  else if (parts[0] == "SPLITWORD")
+  {
+    fprintf(stderr, "cannot explain\n");
+    return;
+  }
 
   auto & labels = c.getTape("LABEL");
   auto & govs = c.getTape("GOV");
-- 
GitLab