diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index ccc385288547de519d0170df54d8ec646824f260..4e3ded1fc8b7517234c247c6daadd5de3b196790 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -85,7 +85,6 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std { torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::getDevice()); getNN()->registerEmbeddings(); - getNN()->to(NeuralNetworkImpl::getDevice()); } else if (std::filesystem::exists(getLastFilename())) { @@ -93,10 +92,8 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std resetOptimizer(); loadOptimizer(); } - else - { - getNN()->to(NeuralNetworkImpl::getDevice()); - } + + getNN()->to(NeuralNetworkImpl::getDevice()); } int Classifier::getNbParameters() const @@ -181,8 +178,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition, void Classifier::loadOptimizer() { auto optimizerPath = std::filesystem::path(fmt::format("{}/{}_optimizer.pt", path.string(), name)); - if (std::filesystem::exists(optimizerPath)) - torch::load(*optimizer, optimizerPath, NeuralNetworkImpl::getDevice()); + torch::load(*optimizer, optimizerPath); } void Classifier::saveOptimizer()