From 47cbed2eaefdd062bb77b554e548ecec81f28066 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 30 Apr 2020 14:09:04 +0200
Subject: [PATCH] Changed name of optimizer checkpoints to avoid confusion with
 models checkpoints

---
 reading_machine/src/ReadingMachine.cpp | 11 ++++++++++-
 trainer/src/MacaonTrain.cpp            |  2 +-
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index ec05ad5..2900e4d 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -30,7 +30,16 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file
   classifier->getNN()->registerEmbeddings(maxDictSize);
   classifier->getNN()->to(NeuralNetworkImpl::device);
 
-  torch::load(classifier->getNN(), models[0]);
+  if (models.size() > 1)
+    util::myThrow("having more than one model file is not supported");
+
+  try
+  {
+    torch::load(classifier->getNN(), models[0]);
+  } catch (std::exception & e)
+  {
+    util::myThrow(fmt::format("error when loading '{}' : {}", models[0].string(), e.what()));
+  }
 }
 
 void ReadingMachine::readFromFile(std::filesystem::path path)
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 6832e3b..7d8a52b 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -168,7 +168,7 @@ int MacaonTrain::main()
     trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval);
 
   machine.getClassifier()->resetOptimizer();
-  auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt";
+  auto optimizerCheckpoint = machinePath.parent_path() / "checkpoint.optimizer";
   if (std::filesystem::exists(trainInfos))
     machine.getClassifier()->loadOptimizer(optimizerCheckpoint);
 
-- 
GitLab