From a5f8ec235af194380114da236ab2c765381474f0 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 15 May 2019 13:50:17 +0200
Subject: [PATCH] Fixed backtrack prediction

---
 trainer/include/Trainer.hpp               | 12 +++++++-----
 trainer/src/Trainer.cpp                   | 10 +++++++---
 transition_machine/include/Classifier.hpp |  7 +++++++
 transition_machine/src/Classifier.cpp     |  5 +++++
 4 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 1294283..1f1171c 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -32,6 +32,11 @@ class Trainer
     }
   };
 
+  public :
+
+  /// @brief The FeatureDescritpion of a Config.
+  using FD = FeatureModel::FeatureDescription;
+
   private :
 
   /// @brief The TransitionMachine that will be trained.
@@ -63,11 +68,8 @@ class Trainer
   float currentSpeed;
   /// @brief The date the last time the speed has been computed.
   std::chrono::time_point<std::chrono::high_resolution_clock> pastTime;
-
-  public :
-
-  /// @brief The FeatureDescritpion of a Config.
-  using FD = FeatureModel::FeatureDescription;
+  /// @brief For each classifier, a FeatureDescription it needs to remember for a future update.
+  std::map<std::string,FD> pendingFD;
 
   private :
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index fd46e97..71ef4f1 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -249,7 +249,6 @@ void Trainer::doStepTrain()
   std::string oAction = "";
   bool pActionIsZeroCost = false;
 
-  
   std::string actionName = "";
   float loss = 0.0;
 
@@ -400,14 +399,14 @@ void Trainer::doStepTrain()
 
       if (newCost >= lastCost)
       {
-        loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex("EPSILON"));
+        loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex("EPSILON"));
   
         if (pActionIsZeroCost)
           TI.addTrainSuccess(tm.getCurrentClassifier()->name);
       }
       else
       {
-        loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top()));
+        loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top()));
         TI.addTrainSuccess(tm.getCurrentClassifier()->name);
       }
 
@@ -421,6 +420,11 @@ void Trainer::doStepTrain()
       tm.getCurrentClassifier()->printWeightedActions(stderr, weightedActions, 10);
       fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str());
     }
+
+    if (actionName != "EPSILON")
+    {
+      pendingFD[tm.getCurrentClassifier()->name] = tm.getCurrentClassifier()->getFeatureDescription(trainConfig);
+    }
   }
 
   Action * action = tm.getCurrentClassifier()->getAction(actionName);
diff --git a/transition_machine/include/Classifier.hpp b/transition_machine/include/Classifier.hpp
index 0ece5ed..a432fef 100644
--- a/transition_machine/include/Classifier.hpp
+++ b/transition_machine/include/Classifier.hpp
@@ -156,6 +156,13 @@ class Classifier
   ///
   /// @return The loss.
   float trainOnExample(Config & config, int gold);
+  /// @brief Train the classifier on a training example.
+  ///
+  /// @param fd The FeatureDescription to work with.
+  /// @param gold The gold class of the FeatureDescription.
+  ///
+  /// @return The loss.
+  float trainOnExample(FeatureModel::FeatureDescription & fd, int gold);
   /// @brief Get the loss of the classifier on a training example.
   ///
   /// @param config The Config to work with.
diff --git a/transition_machine/src/Classifier.cpp b/transition_machine/src/Classifier.cpp
index 6059fba..1eb8727 100644
--- a/transition_machine/src/Classifier.cpp
+++ b/transition_machine/src/Classifier.cpp
@@ -273,6 +273,11 @@ float Classifier::trainOnExample(Config & config, int gold)
   return nn->update(fd, gold);
 }
 
+float Classifier::trainOnExample(FeatureModel::FeatureDescription & fd, int gold)
+{
+  return nn->update(fd, gold);
+}
+
 float Classifier::getLoss(Config & config, int gold)
 {
   auto & fd = fm->getFeatureDescription(config);
-- 
GitLab