From aa79c9cfbaf6aa9ba7cf43cd4c1ac64baeda0109 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 14 Jun 2020 18:19:04 +0200
Subject: [PATCH] close dicts when extracting dev dataset

---
 trainer/src/MacaonTrain.cpp | 8 ++++++++
 trainer/src/Trainer.cpp     | 1 -
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 63a32de..9227e72 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -207,9 +207,13 @@ int MacaonTrain::main()
     }
     if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold))
     {
+      machine.setDictsState(Dict::State::Open);
       trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic));
       if (!computeDevScore)
+      {
+        machine.setDictsState(Dict::State::Closed);
         trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic));
+      }
     }
     if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetOptimizer))
     {
@@ -225,9 +229,13 @@ int MacaonTrain::main()
     }
     if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic))
     {
+      machine.setDictsState(Dict::State::Open);
       trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic));
       if (!computeDevScore)
+      {
+        machine.setDictsState(Dict::State::Closed);
         trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic));
+      }
     }
     if (trainStrategy[currentEpoch].count(Trainer::TrainAction::Save))
     {
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 66efbfe..ff4f1a3 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -22,7 +22,6 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
   machine.trainMode(false);
-  machine.setDictsState(Dict::State::Open);
 
   extractExamples(config, debug, dir, epoch, dynamicOracle);
 
-- 
GitLab