From 8d4d9ab5c05f0c4261cb3c61190027a7bd157445 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 18 Jun 2020 22:03:53 +0200
Subject: [PATCH] Make sure nothing is added to the dict mid-training

---
 reading_machine/src/ReadingMachine.cpp | 1 +
 torch_modules/src/Submodule.cpp        | 2 ++
 trainer/src/MacaonTrain.cpp            | 3 ++-
 trainer/src/Trainer.cpp                | 1 -
 4 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index e1b8169..f5fd3c4 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -11,6 +11,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file
   readFromFile(path);
 
   loadDicts();
+  trainMode(false);
   classifier->getNN()->registerEmbeddings();
   classifier->getNN()->to(NeuralNetworkImpl::device);
 
diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index ea63b99..66f6455 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -9,6 +9,8 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s
 {
   if (path.empty())
     return;
+  if (!is_training())
+    return;
 
   if (!std::filesystem::exists(path))
     util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 21529bf..8e43cd3 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -207,7 +207,7 @@ int MacaonTrain::main()
     }
     if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic))
     {
-      machine.setDictsState(Dict::State::Open);
+      machine.setDictsState(trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic) ? Dict::State::Closed : Dict::State::Open);
       trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic));
       if (!computeDevScore)
       {
@@ -220,6 +220,7 @@ int MacaonTrain::main()
       if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters))
       {
         machine.resetClassifier();
+        machine.trainMode(currentEpoch == 0);
         machine.getClassifier()->getNN()->registerEmbeddings();
         machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
         fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters()));
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 1d363c5..af1ef2e 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -155,7 +155,6 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
 
   torch::AutoGradMode useGrad(train);
   machine.trainMode(train);
-  machine.setDictsState(Dict::State::Closed);
 
   auto lossFct = torch::nn::CrossEntropyLoss();
 
-- 
GitLab