From e45d45e60d81ad9f55df4a837b8f294779f37f44 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 1 Apr 2020 22:27:37 +0200
Subject: [PATCH] Checkpoints are created after each training epoch and it is
 possible to resume a training by training again on the same directory

---
 reading_machine/include/ReadingMachine.hpp |  5 ++-
 reading_machine/src/ReadingMachine.cpp     | 27 ++++++++++++---
 trainer/include/Trainer.hpp                |  2 ++
 trainer/src/MacaonTrain.cpp                | 38 ++++++++++++++++++++--
 trainer/src/Trainer.cpp                    | 10 ++++++
 5 files changed, 74 insertions(+), 8 deletions(-)

diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 4ce25aa..cc56ec2 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -13,6 +13,7 @@ class ReadingMachine
 
   static inline const std::string defaultMachineFilename = "machine.rm";
   static inline const std::string defaultModelFilename = "{}.pt";
+  static inline const std::string lastModelFilename = "{}.last";
   static inline const std::string defaultDictFilename = "{}.dict";
   static inline const std::string defaultDictName = "_default_";
 
@@ -28,6 +29,7 @@ class ReadingMachine
   private :
 
   void readFromFile(std::filesystem::path path);
+  void save(const std::string & modelNameTemplate) const;
 
   public :
 
@@ -38,10 +40,11 @@ class ReadingMachine
   Dict & getDict(const std::string & state);
   std::map<std::string, Dict> & getDicts();
   Classifier * getClassifier();
-  void save() const;
   bool isPredicted(const std::string & columnName) const;
   const std::set<std::string> & getPredicted() const;
   void trainMode(bool isTrainMode);
+  void saveBest() const;
+  void saveLast() const;
 };
 
 #endif
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index a74491b..50ce655 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -3,9 +3,18 @@
 
 ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
 {
-  dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open));
-
   readFromFile(path);
+
+  auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, ""));
+  auto savedDicts = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::defaultDictFilename, ""));
+  if (!lastSavedModel.empty())
+    torch::load(classifier->getNN(), lastSavedModel[0]);
+
+  for (auto path : savedDicts)
+    this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Open});
+
+  if (dicts.count(defaultDictName) == 0)
+    dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open));
 }
 
 ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts)
@@ -98,7 +107,7 @@ Classifier * ReadingMachine::getClassifier()
   return classifier.get();
 }
 
-void ReadingMachine::save() const
+void ReadingMachine::save(const std::string & modelNameTemplate) const
 {
   for (auto & it : dicts)
   {
@@ -112,10 +121,20 @@ void ReadingMachine::save() const
     std::fclose(file);
   }
 
-  auto pathToClassifier = path.parent_path() / fmt::format(defaultModelFilename, classifier->getName());
+  auto pathToClassifier = path.parent_path() / fmt::format(modelNameTemplate, classifier->getName());
   torch::save(classifier->getNN(), pathToClassifier);
 }
 
+void ReadingMachine::saveBest() const
+{
+  save(defaultModelFilename);
+}
+
+void ReadingMachine::saveLast() const
+{
+  save(lastModelFilename);
+}
+
 bool ReadingMachine::isPredicted(const std::string & columnName) const
 {
   return predicted.count(columnName);
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 259a150..b5a548c 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -34,6 +34,8 @@ class Trainer
   void createDevDataset(SubConfig & goldConfig, bool debug);
   float epoch(bool printAdvancement);
   float evalOnDev(bool printAdvancement);
+  void loadOptimizer(std::filesystem::path path);
+  void saveOptimizer(std::filesystem::path path);
 };
 
 #endif
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 5f5db36..7b8e60f 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -117,7 +117,33 @@ int MacaonTrain::main()
 
   float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
 
-  for (int i = 0; i < nbEpoch; i++)
+  auto trainInfos = machinePath.parent_path() / "train.info";
+
+  int currentEpoch = 0;
+
+  if (std::filesystem::exists(trainInfos))
+  {
+    std::FILE * f = std::fopen(trainInfos.c_str(), "r");
+    char buffer[1024];
+    while (!std::feof(f))
+    {
+      if (buffer != std::fgets(buffer, 1024, f))
+        break;
+      float devScoreMean = std::stof(util::split(buffer, '\t').back());
+      if (computeDevScore and devScoreMean > bestDevScore)
+        bestDevScore = devScoreMean;
+      if (!computeDevScore and devScoreMean < bestDevScore)
+        bestDevScore = devScoreMean;
+      currentEpoch++;
+    }
+    std::fclose(f);
+  }
+
+  auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt";
+  if (std::filesystem::exists(trainInfos))
+    trainer.loadOptimizer(optimizerCheckpoint);
+
+  for (; currentEpoch < nbEpoch; currentEpoch++)
   {
     float loss = trainer.epoch(printAdvancement);
     machine.getStrategy().reset();
@@ -157,11 +183,17 @@ int MacaonTrain::main()
     if (saved)
     {
       bestDevScore = devScoreMean;
-      machine.save();
+      machine.saveBest();
     }
+    machine.saveLast();
+    trainer.saveOptimizer(optimizerCheckpoint);
     if (!debug)
       fmt::print(stderr, "\r{:80}\r", "");
-    fmt::print(stderr, "[{}] Epoch {:^5} loss = {:6.1f} dev = {} {:5}\n", util::getTime(), fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
+    std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.1f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
+    fmt::print(stderr, "{}\n", iterStr);
+    std::FILE * f = std::fopen(trainInfos.c_str(), "a");
+    fmt::print(f, "{}\t{}\n", iterStr, devScoreMean);
+    std::fclose(f);
   }
 
   }
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 501af8e..6ebf18e 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -163,3 +163,13 @@ float Trainer::evalOnDev(bool printAdvancement)
   return processDataset(devDataLoader, false, printAdvancement);
 }
 
+void Trainer::loadOptimizer(std::filesystem::path path)
+{
+  torch::load(*optimizer, path);
+}
+
+void Trainer::saveOptimizer(std::filesystem::path path)
+{
+  torch::save(*optimizer, path);
+}
+
-- 
GitLab