From aa79c9cfbaf6aa9ba7cf43cd4c1ac64baeda0109 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sun, 14 Jun 2020 18:19:04 +0200 Subject: [PATCH] close dicts when extracting dev dataset --- trainer/src/MacaonTrain.cpp | 8 ++++++++ trainer/src/Trainer.cpp | 1 - 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 63a32de..9227e72 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -207,9 +207,13 @@ int MacaonTrain::main() } if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold)) { + machine.setDictsState(Dict::State::Open); trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)); if (!computeDevScore) + { + machine.setDictsState(Dict::State::Closed); trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)); + } } if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetOptimizer)) { @@ -225,9 +229,13 @@ int MacaonTrain::main() } if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)) { + machine.setDictsState(Dict::State::Open); trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)); if (!computeDevScore) + { + machine.setDictsState(Dict::State::Closed); trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)); + } } if (trainStrategy[currentEpoch].count(Trainer::TrainAction::Save)) { diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 66efbfe..ff4f1a3 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -22,7 +22,6 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem SubConfig config(goldConfig, goldConfig.getNbLines()); machine.trainMode(false); - machine.setDictsState(Dict::State::Open); extractExamples(config, debug, dir, epoch, dynamicOracle); -- GitLab