From 0daae795e17c19af16fe083778e22a0b1ef9ddc2 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 14 Apr 2020 15:03:23 +0200
Subject: [PATCH] do not close dict during example extraction

---
 decoder/src/Decoder.cpp                    |  1 +
 reading_machine/include/ReadingMachine.hpp |  2 ++
 reading_machine/src/ReadingMachine.cpp     | 13 +++++++++++--
 trainer/src/MacaonTrain.cpp                |  1 -
 trainer/src/Trainer.cpp                    |  4 ++++
 5 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 1c16521..9b6b3a6 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -9,6 +9,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
 {
   torch::AutoGradMode useGrad(false);
   machine.trainMode(false);
+  machine.setDictsState(Dict::State::Closed);
   machine.getStrategy().reset();
   config.addPredicted(machine.getPredicted());
 
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index dcece9a..9eb09d0 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -47,8 +47,10 @@ class ReadingMachine
   bool isPredicted(const std::string & columnName) const;
   const std::set<std::string> & getPredicted() const;
   void trainMode(bool isTrainMode);
+  void setDictsState(Dict::State state);
   void saveBest() const;
   void saveLast() const;
+  void saveDicts() const;
 };
 
 #endif
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 9bd3e2f..38f79c8 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -134,7 +134,7 @@ Classifier * ReadingMachine::getClassifier()
   return classifier.get();
 }
 
-void ReadingMachine::save(const std::string & modelNameTemplate) const
+void ReadingMachine::saveDicts() const
 {
   for (auto & it : dicts)
   {
@@ -147,6 +147,11 @@ void ReadingMachine::save(const std::string & modelNameTemplate) const
 
     std::fclose(file);
   }
+}
+
+void ReadingMachine::save(const std::string & modelNameTemplate) const
+{
+  saveDicts();
 
   auto pathToClassifier = path.parent_path() / fmt::format(modelNameTemplate, classifier->getName());
   torch::save(classifier->getNN(), pathToClassifier);
@@ -175,8 +180,12 @@ const std::set<std::string> & ReadingMachine::getPredicted() const
 void ReadingMachine::trainMode(bool isTrainMode)
 {
   classifier->getNN()->train(isTrainMode);
+}
+
+void ReadingMachine::setDictsState(Dict::State state)
+{
   for (auto & it : dicts)
-    it.second.setState(isTrainMode ? Dict::State::Open : Dict::State::Closed);
+    it.second.setState(state);
 }
 
 std::map<std::string, Dict> & ReadingMachine::getDicts()
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 280643a..08f0684 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -124,7 +124,6 @@ int MacaonTrain::main()
 
   fillDicts(machine, goldConfig);
 
-
   Trainer trainer(machine, batchSize);
   Decoder decoder(machine);
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 5d3b2ec..efe4c2b 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -44,6 +44,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
 {
   torch::AutoGradMode useGrad(false);
   machine.trainMode(false);
+  machine.setDictsState(Dict::State::Open);
 
   int maxNbExamplesPerFile = 250000;
   int currentExampleIndex = 0;
@@ -163,6 +164,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
     util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
   std::fclose(f);
 
+  machine.saveDicts();
+
   fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(currentExampleIndex));
 }
 
@@ -176,6 +179,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
 
   torch::AutoGradMode useGrad(train);
   machine.trainMode(train);
+  machine.setDictsState(Dict::State::Closed);
 
   auto lossFct = torch::nn::CrossEntropyLoss();
 
-- 
GitLab