From 1795a1b32b95179ed2ad6f40eccc4505ad351a32 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sun, 21 Mar 2021 12:54:35 +0100 Subject: [PATCH] When loading optimizer, do not silently fail --- reading_machine/src/Classifier.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index ccc3852..4e3ded1 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -85,7 +85,6 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std { torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::getDevice()); getNN()->registerEmbeddings(); - getNN()->to(NeuralNetworkImpl::getDevice()); } else if (std::filesystem::exists(getLastFilename())) { @@ -93,10 +92,8 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std resetOptimizer(); loadOptimizer(); } - else - { - getNN()->to(NeuralNetworkImpl::getDevice()); - } + + getNN()->to(NeuralNetworkImpl::getDevice()); } int Classifier::getNbParameters() const @@ -181,8 +178,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition, void Classifier::loadOptimizer() { auto optimizerPath = std::filesystem::path(fmt::format("{}/{}_optimizer.pt", path.string(), name)); - if (std::filesystem::exists(optimizerPath)) - torch::load(*optimizer, optimizerPath, NeuralNetworkImpl::getDevice()); + torch::load(*optimizer, optimizerPath); } void Classifier::saveOptimizer() -- GitLab