From a8613b8b5784ff89966a68f5f9523c5a6c63d7e2 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 11 May 2020 14:13:07 +0200
Subject: [PATCH] Improved model selection when using dynamic oracle

---
 trainer/src/MacaonTrain.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index efcab79..20d03eb 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -82,7 +82,6 @@ int MacaonTrain::main()
   auto nbEpoch = variables["nbEpochs"].as<int>();
   auto batchSize = variables["batchSize"].as<int>();
   auto dynamicOracleInterval = variables["dynamicOracleInterval"].as<int>();
-  bool saveEverything = dynamicOracleInterval > 0;
   auto rarityThreshold = variables["rarityThreshold"].as<float>();
   bool debug = variables.count("debug") == 0 ? false : true;
   bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
@@ -145,9 +144,9 @@ int MacaonTrain::main()
       if (buffer != std::fgets(buffer, 1024, f))
         break;
       float devScoreMean = std::stof(util::split(buffer, '\t').back());
-      if (computeDevScore and devScoreMean > bestDevScore)
+      if (computeDevScore and (devScoreMean > bestDevScore or currentEpoch == dynamicOracleInterval))
         bestDevScore = devScoreMean;
-      if (!computeDevScore and devScoreMean < bestDevScore)
+      if (!computeDevScore and (devScoreMean < bestDevScore or currentEpoch == dynamicOracleInterval))
         bestDevScore = devScoreMean;
       currentEpoch++;
     }
@@ -204,7 +203,7 @@ int MacaonTrain::main()
     if (!computeDevScore)
       saved = devScoreMean <= bestDevScore;
 
-    if (saveEverything)
+    if (currentEpoch == dynamicOracleInterval)
       saved = true;
     if (saved)
     {
-- 
GitLab