From 006cdfc9073d55656ff50a89561340b03477bb18 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 30 Jul 2019 15:32:15 +0200
Subject: [PATCH] Fix backtrack prediction by training on a full gold distribution

---
 neural_network/src/MLPBase.cpp |  2 +-
 trainer/src/Trainer.cpp        | 26 +++++++++++++++++++++-----
 2 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/neural_network/src/MLPBase.cpp b/neural_network/src/MLPBase.cpp
index 00ce27c..18585c0 100644
--- a/neural_network/src/MLPBase.cpp
+++ b/neural_network/src/MLPBase.cpp
@@ -162,7 +162,7 @@ float MLPBase::update(FeatureModel::FeatureDescription & fd, const std::vector<f
   dynet::Expression batchedLoss;
   std::vector<dynet::Expression> goldExpressions;
   for (auto & gold : goldsContinuous)
-    goldExpressions.emplace_back(dynet::input(cg, dynet::Dim({1,(unsigned int)gold.size()}), gold));
+    goldExpressions.emplace_back(dynet::input(cg, dynet::Dim({(unsigned int)gold.size()}), gold));
  
   dynet::Expression batchedGold = dynet::concatenate_to_batch(goldExpressions);
   batchedLoss = dynet::sum_batches(dynet::squared_distance(output, batchedGold));
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 61032a3..38c6a35 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -400,20 +400,36 @@ void Trainer::doStepTrain()
 
       if (TI.lastActionWasPredicted[normalStateName])
       {
+        std::string updateInfos;
+
         if (newCost >= lastCost)
         {
-          loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex("EPSILON"));
+//          loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex("EPSILON"));
+          int nbActions = tm.getCurrentClassifier()->getNbActions();
+          int backIndex = tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top());
+          float value = 1.0 / (nbActions-1);
+          std::vector<float> goldOutput(nbActions, value);
+          goldOutput[backIndex] = 0.0;
+
+          loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], goldOutput);
+
+          updateInfos = "predicted : <"+trainConfig.getCurrentStateHistory().top()+">, bad decision";
         }
         else
         {
-          loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top()));
+//          loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top()));
+          int nbActions = tm.getCurrentClassifier()->getNbActions();
+          int backIndex = tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top());
+          std::vector<float> goldOutput(nbActions, 0.0);
+          goldOutput[backIndex] = 1.0;
+
+          loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], goldOutput);
 
-          if (ProgramParameters::debug)
-            fprintf(stderr, "Updating neural network \'%s\', gold=\'%s\'\n", tm.getCurrentClassifier()->name.c_str(), trainConfig.getCurrentStateHistory().top().c_str());
+          updateInfos = "predicted : <"+trainConfig.getCurrentStateHistory().top()+">, good decision";
         }
 
         if (ProgramParameters::debug)
-          fprintf(stderr, "Updating neural network \'%s\'\n", tm.getCurrentClassifier()->name.c_str());
+          fprintf(stderr, "Updating neural network \'%s\' : %s\n", tm.getCurrentClassifier()->name.c_str(), updateInfos.c_str());
 
         TI.addTrainLoss(tm.getCurrentClassifier()->name, loss);
       }
-- 
GitLab