From 9e8e91bedaaa60236264b52e24082ff4141bd51f Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 29 Apr 2019 22:56:17 +0200
Subject: [PATCH] Probably fixed error prediction

---
 trainer/src/Trainer.cpp               | 4 +++-
 transition_machine/include/Action.hpp | 3 ++-
 transition_machine/include/Config.hpp | 2 +-
 transition_machine/src/Action.cpp     | 8 +++++---
 transition_machine/src/ActionBank.cpp | 2 +-
 5 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index a2c8476..13de243 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -347,6 +347,8 @@ void Trainer::doStepTrain()
     }
 
     actionName = pAction;
+    if (choiceWithProbability(0.5))
+      actionName = oAction;
 
     char buffer[1024];
     if (sscanf(tm.getCurrentClassifier()->name.c_str(), "Error_%s", buffer) != 1)
@@ -361,7 +363,7 @@ void Trainer::doStepTrain()
     auto & normalHistory = trainConfig.getActionsHistory(normalClassifierName);
 
     // If a BACK just happened
-    if (normalHistory.size() > 1 && errorHistory.size() > 0 && TI.getEpoch() >= ProgramParameters::dynamicEpoch)
+    if (normalHistory.size() > 1 && trainConfig.getCurrentStateHistory().size() > 0 && trainConfig.getCurrentStateHistory().top() != "EPSILON" && TI.getEpoch() >= ProgramParameters::dynamicEpoch)
     {
       auto & lastAction = normalHistory[normalHistory.size()-2];
       auto & newAction = normalHistory[normalHistory.size()-1];
diff --git a/transition_machine/include/Action.hpp b/transition_machine/include/Action.hpp
index 120212a..e49e383 100644
--- a/transition_machine/include/Action.hpp
+++ b/transition_machine/include/Action.hpp
@@ -29,7 +29,8 @@ class Action
     {
       Push,
       Pop,
-      Write
+      Write,
+      Back
     };
 
     /// @brief The type of this BasicAction.
diff --git a/transition_machine/include/Config.hpp b/transition_machine/include/Config.hpp
index 5078172..9aba457 100644
--- a/transition_machine/include/Config.hpp
+++ b/transition_machine/include/Config.hpp
@@ -134,7 +134,7 @@ class Config
 
   private :
 
-  const unsigned int HISTORY_SIZE = 1000;
+  const unsigned int HISTORY_SIZE = 100000;
   /// @brief The name of the current state of the TransitionMachine.
   std::string currentStateName;
   /// @brief For each state of the TransitionMachine, an history of the Action that have been applied to this Config.
diff --git a/transition_machine/src/Action.cpp b/transition_machine/src/Action.cpp
index fabde9d..4039013 100644
--- a/transition_machine/src/Action.cpp
+++ b/transition_machine/src/Action.cpp
@@ -10,7 +10,7 @@ void Action::apply(Config & config)
   for(auto & basicAction : sequence)
     basicAction.apply(config, basicAction);
 
-  config.getCurrentStateHistory().push(namePrefix);
+  config.getCurrentStateHistory().push(name);
   config.pastActions.push(std::pair<std::string, Action>(config.getCurrentStateName(), *this));
 
   config.moveHead(headMovement);
@@ -48,7 +48,7 @@ void Action::undoOnlyStack(Config & config)
   for(int i = sequence.size()-1; i >= 0; i--)
   {
     auto type = sequence[i].type;
-    if(type == BasicAction::Type::Write)
+    if(type == BasicAction::Type::Write || type == BasicAction::Type::Back)
       continue;
 
     sequence[i].undo(config, sequence[i]);
@@ -57,7 +57,9 @@ void Action::undoOnlyStack(Config & config)
   if (ProgramParameters::debug)
     fprintf(stderr, "Undoing only stack action <%s><%s>, state history size = %d past actions size = %d...", stateName.c_str(), name.c_str(), config.getStateHistory(stateName).size(), config.pastActions.size());
 
-  config.getStateHistory(stateName).pop();
+  char buffer[1024];
+  if (sscanf(stateName.c_str(), "error_%s", buffer) != 1)
+    config.getStateHistory(stateName).pop();
 
   if (ProgramParameters::debug)
     fprintf(stderr, "done\n");
diff --git a/transition_machine/src/ActionBank.cpp b/transition_machine/src/ActionBank.cpp
index f7b1cbf..4224c1e 100644
--- a/transition_machine/src/ActionBank.cpp
+++ b/transition_machine/src/ActionBank.cpp
@@ -578,7 +578,7 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na
           return true;
         };
         Action::BasicAction basicAction =
-          {Action::BasicAction::Type::Write, "", apply, undo, appliable};
+          {Action::BasicAction::Type::Back, name, apply, undo, appliable};
 
         sequence.emplace_back(basicAction);
     }
-- 
GitLab