From d145be52a13ddda6a522338a2edc08ab777ea579 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 25 Feb 2020 23:09:43 +0100
Subject: [PATCH] Reseting strategy between different corpuses

---
 reading_machine/include/Strategy.hpp |  2 ++
 reading_machine/src/Strategy.cpp     | 12 ++++++++++--
 trainer/src/macaon_train.cpp         |  6 ++++--
 3 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/reading_machine/include/Strategy.hpp b/reading_machine/include/Strategy.hpp
index 46946ab..9c4edd5 100644
--- a/reading_machine/include/Strategy.hpp
+++ b/reading_machine/include/Strategy.hpp
@@ -21,6 +21,7 @@ class Strategy
   std::map<std::pair<std::string, std::string>, std::pair<std::string, int>> edges;
   std::map<std::string, bool> isDone;
   std::vector<std::string> defaultCycle;
+  std::vector<std::string> originalDefaultCycle;
   std::string initialState{"UNDEFINED"};
 
   private :
@@ -33,6 +34,7 @@ class Strategy
   Strategy(const std::vector<std::string_view> & lines);
   std::pair<std::string, int> getMovement(const Config & c, const std::string & transition);
   const std::string getInitialState() const;
+  void reset();
 };
 
 #endif
diff --git a/reading_machine/src/Strategy.cpp b/reading_machine/src/Strategy.cpp
index ee1efa6..a13ac5e 100644
--- a/reading_machine/src/Strategy.cpp
+++ b/reading_machine/src/Strategy.cpp
@@ -41,6 +41,7 @@ Strategy::Strategy(const std::vector<std::string_view> & lines)
     util::myThrow("Strategy is empty");
   defaultCycle.pop_back();
   std::reverse(defaultCycle.begin(), defaultCycle.end());
+  originalDefaultCycle = defaultCycle;
 }
 
 std::pair<std::string, int> Strategy::getMovement(const Config & c, const std::string & transition)
@@ -96,7 +97,7 @@ std::pair<std::string, int> Strategy::getMovementIncremental(const Config & c, c
   auto foundGeneric = edges.find(std::make_pair(c.getState(), ""));
 
   std::string target;
-  int movement;
+  int movement = -1;
 
   if (foundSpecific != edges.end())
   {
@@ -113,7 +114,7 @@ std::pair<std::string, int> Strategy::getMovementIncremental(const Config & c, c
     util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition));
 
   if (!isDone[target])
-    return {target, target == defaultCycle.back() ? movement : 0};
+    return {target, c.canMoveWordIndex(movement) ? movement : 0};
 
   if (defaultCycle.empty())
     return endMovement;
@@ -126,3 +127,10 @@ const std::string Strategy::getInitialState() const
   return initialState;
 }
 
+void Strategy::reset()
+{
+  for (auto & it : isDone)
+    it.second = false;
+  defaultCycle = originalDefaultCycle;
+}
+
diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp
index 936d16c..34a009a 100644
--- a/trainer/src/macaon_train.cpp
+++ b/trainer/src/macaon_train.cpp
@@ -92,12 +92,14 @@ int main(int argc, char * argv[])
   for (int i = 0; i < nbEpoch; i++)
   {
     float loss = trainer.epoch(!debug);
+    machine.getStrategy().reset();
     auto devConfig = devGoldConfig;
     if (debug)
       fmt::print(stderr, "Decoding dev :\n");
     else
       fmt::print(stderr, "\r{:80}\rDecoding dev...", " ");
     decoder.decode(devConfig, 1, debug);
+    machine.getStrategy().reset();
     decoder.evaluate(devConfig, modelPath, devTsvFile);
     std::vector<std::pair<float,std::string>> devScores = decoder.getF1Scores(machine.getPredicted());
     std::string devScoresStr = "";
@@ -117,9 +119,9 @@ int main(int argc, char * argv[])
       machine.save();
     }
     if (debug)
-      fmt::print(stderr, "Epoch {:^9} loss = {:7.2f} dev = {} {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
+      fmt::print(stderr, "Epoch {:^5} loss = {:6.1f} dev = {} {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
     else
-      fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
+      fmt::print(stderr, "\r{:80}\rEpoch {:^5} loss = {:6.1f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
   }
 
   }
-- 
GitLab