From 93b2c58cf28a7095db669f5c51d015e82760b8e9 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 30 Apr 2020 12:52:12 +0200
Subject: [PATCH] Corrected bug where embeddings were not loaded when training
 resumed

---
 reading_machine/include/ReadingMachine.hpp |  1 +
 reading_machine/src/ReadingMachine.cpp     | 10 +++++++---
 trainer/src/MacaonTrain.cpp                |  5 +++--
 3 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 8a51466..34f1745 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -53,6 +53,7 @@ class ReadingMachine
   void saveLast() const;
   void saveDicts() const;
   bool dictsAreNew() const;
+  void loadLastSaved();
 };
 
 #endif
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 70d8123..ec05ad5 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -5,10 +5,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
 {
   readFromFile(path);
 
-  auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, ""));
   auto savedDicts = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::defaultDictFilename, ""));
-  if (!lastSavedModel.empty())
-    torch::load(classifier->getNN(), lastSavedModel[0]);
 
   for (auto path : savedDicts)
     this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Open});
@@ -207,3 +204,10 @@ bool ReadingMachine::dictsAreNew() const
   return _dictsAreNew;
 }
 
+void ReadingMachine::loadLastSaved()
+{
+  auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, ""));
+  if (!lastSavedModel.empty())
+    torch::load(classifier->getNN(), lastSavedModel[0]);
+}
+
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 3bb18a2..86acd6b 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -106,8 +106,6 @@ int MacaonTrain::main()
 
   ReadingMachine machine(machinePath.string());
 
-  fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters()));
-
   BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
   BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
 
@@ -136,8 +134,11 @@ int MacaonTrain::main()
   for (auto & it : machine.getDicts())
     maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size());
   machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize);
+  machine.loadLastSaved();
   machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
 
+  fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters()));
+
   float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
 
   auto trainInfos = machinePath.parent_path() / "train.info";
-- 
GitLab