From a8613b8b5784ff89966a68f5f9523c5a6c63d7e2 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 11 May 2020 14:13:07 +0200 Subject: [PATCH] Improved model selection when using dynamic oracle --- trainer/src/MacaonTrain.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index efcab79..20d03eb 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -82,7 +82,6 @@ int MacaonTrain::main() auto nbEpoch = variables["nbEpochs"].as<int>(); auto batchSize = variables["batchSize"].as<int>(); auto dynamicOracleInterval = variables["dynamicOracleInterval"].as<int>(); - bool saveEverything = dynamicOracleInterval > 0; auto rarityThreshold = variables["rarityThreshold"].as<float>(); bool debug = variables.count("debug") == 0 ? false : true; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; @@ -145,9 +144,9 @@ int MacaonTrain::main() if (buffer != std::fgets(buffer, 1024, f)) break; float devScoreMean = std::stof(util::split(buffer, '\t').back()); - if (computeDevScore and devScoreMean > bestDevScore) + if (computeDevScore and (devScoreMean > bestDevScore or currentEpoch == dynamicOracleInterval)) bestDevScore = devScoreMean; - if (!computeDevScore and devScoreMean < bestDevScore) + if (!computeDevScore and (devScoreMean < bestDevScore or currentEpoch == dynamicOracleInterval)) bestDevScore = devScoreMean; currentEpoch++; } @@ -204,7 +203,7 @@ int MacaonTrain::main() if (!computeDevScore) saved = devScoreMean <= bestDevScore; - if (saveEverything) + if (currentEpoch == dynamicOracleInterval) saved = true; if (saved) { -- GitLab