From 1795a1b32b95179ed2ad6f40eccc4505ad351a32 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 21 Mar 2021 12:54:35 +0100
Subject: [PATCH] When loading optimizer, do not silently fail

---
 reading_machine/src/Classifier.cpp | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index ccc3852..4e3ded1 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -85,7 +85,6 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
   {
     torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::getDevice());
     getNN()->registerEmbeddings();
-    getNN()->to(NeuralNetworkImpl::getDevice());
   }
   else if (std::filesystem::exists(getLastFilename()))
   {
@@ -93,10 +92,8 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
     resetOptimizer();
     loadOptimizer();
   }
-  else
-  {
-    getNN()->to(NeuralNetworkImpl::getDevice());
-  }
+
+  getNN()->to(NeuralNetworkImpl::getDevice());
 }
 
 int Classifier::getNbParameters() const
@@ -181,8 +178,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition,
 void Classifier::loadOptimizer()
 {
   auto optimizerPath = std::filesystem::path(fmt::format("{}/{}_optimizer.pt", path.string(), name));
-  if (std::filesystem::exists(optimizerPath))
-    torch::load(*optimizer, optimizerPath, NeuralNetworkImpl::getDevice());
+  torch::load(*optimizer, optimizerPath);
 }
 
 void Classifier::saveOptimizer()
-- 
GitLab