From 0e82071312a9c6984fb681e1a67aa4dd54cff495 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 24 Apr 2019 14:36:43 +0200
Subject: [PATCH] Prepared for classifier of error detection

---
 decoder/src/macaon_decode.cpp         |  2 +-
 trainer/src/macaon_train.cpp          |  2 +-
 transition_machine/src/ActionBank.cpp | 20 ++++++++++++++++----
 transition_machine/src/Oracle.cpp     |  3 +++
 4 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/decoder/src/macaon_decode.cpp b/decoder/src/macaon_decode.cpp
index b005aa8..f6e9b94 100644
--- a/decoder/src/macaon_decode.cpp
+++ b/decoder/src/macaon_decode.cpp
@@ -51,7 +51,7 @@ po::options_description getOptionsDescription()
       "For each state of the Config, show its feature representation")
     ("readSize", po::value<int>()->default_value(0),
       "The number of lines of input that will be read and stored in memory at once.")
-    ("dictCapacity", po::value<int>()->default_value(30000),
+    ("dictCapacity", po::value<int>()->default_value(50000),
       "The maximal size of each Dict (number of differents embeddings).")
     ("interactive", po::value<bool>()->default_value(true),
       "Is the shell interactive ? Display advancement informations")
diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp
index aaa3f7c..7a9a2b8 100644
--- a/trainer/src/macaon_train.cpp
+++ b/trainer/src/macaon_train.cpp
@@ -79,7 +79,7 @@ po::options_description getOptionsDescription()
       "The value of the token that act as a delimiter for sequences")
     ("batchSize", po::value<int>()->default_value(50),
       "The size of each minibatch (in number of taining examples)")
-    ("dictCapacity", po::value<int>()->default_value(30000),
+    ("dictCapacity", po::value<int>()->default_value(50000),
       "The maximal size of each Dict (number of differents embeddings).")
     ("tapeToMask", po::value<std::string>()->default_value("FORM"),
       "The name of the Tape for which some of the elements will be masked.")
diff --git a/transition_machine/src/ActionBank.cpp b/transition_machine/src/ActionBank.cpp
index 7a304b6..d85fc39 100644
--- a/transition_machine/src/ActionBank.cpp
+++ b/transition_machine/src/ActionBank.cpp
@@ -561,14 +561,26 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na
       auto undo = [dist](Config &, Action::BasicAction &)
         {
         };
-      auto appliable = [dist](Config &, Action::BasicAction &)
+      auto appliable = [dist](Config & c, Action::BasicAction &)
         {
+          std::string classifierName = c.pastActions.top().first;
+          int stateHistorySize = c.getStateHistory(classifierName).size();
+          // Reject when the top entry or the entry at index 1 of the current state history is "BACK".
+          if (c.getCurrentStateHistory().size() >= 2 && (c.getCurrentStateHistory().top() == "BACK" || c.getCurrentStateHistory().getElem(1) == "BACK"))
+            return false;
+          // Reject when the hash of the current configuration is already present in hashHistory.
+          if (c.hashHistory.contains(c.computeHash()))
+            return false;
+          // Reject unless the state history of the classifier named by the latest past action holds more than dist entries.
+          if (stateHistorySize <= dist)
+            return false;
+          // None of the guards above fired: report the action as appliable.
           return true;
         };
-      Action::BasicAction basicAction =
-        {Action::BasicAction::Type::Write, "", apply, undo, appliable};
+        Action::BasicAction basicAction =
+          {Action::BasicAction::Type::Write, "", apply, undo, appliable};
 
-      sequence.emplace_back(basicAction);
+        sequence.emplace_back(basicAction);
     }
     else
     {
diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp
index 805b62c..bc54f85 100644
--- a/transition_machine/src/Oracle.cpp
+++ b/transition_machine/src/Oracle.cpp
@@ -97,6 +97,7 @@ void Oracle::createDatabase()
   str2oracle.emplace("error_tagger", std::unique_ptr<Oracle>(new Oracle(
   [](Oracle * oracle)
   {
+    return;
     File file(oracle->filename, "r");
     FILE * fd = file.getDescriptor();
     char b1[1024];
@@ -155,6 +156,7 @@ void Oracle::createDatabase()
   str2oracle.emplace("error_morpho", std::unique_ptr<Oracle>(new Oracle(
   [](Oracle * oracle)
   {
+    return;
     File file(oracle->filename, "r");
     FILE * fd = file.getDescriptor();
     char b1[1024];
@@ -220,6 +222,7 @@ void Oracle::createDatabase()
   str2oracle.emplace("error_parser", std::unique_ptr<Oracle>(new Oracle(
   [](Oracle * oracle)
   {
+    return;
     File file(oracle->filename, "r");
     FILE * fd = file.getDescriptor();
     char b1[1024];
-- 
GitLab