From 47cbed2eaefdd062bb77b554e548ecec81f28066 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Thu, 30 Apr 2020 14:09:04 +0200 Subject: [PATCH] Changed name of optimizer checkpoints to avoid confusion with models checkpoints --- reading_machine/src/ReadingMachine.cpp | 11 ++++++++++- trainer/src/MacaonTrain.cpp | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index ec05ad5..2900e4d 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -30,7 +30,16 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file classifier->getNN()->registerEmbeddings(maxDictSize); classifier->getNN()->to(NeuralNetworkImpl::device); - torch::load(classifier->getNN(), models[0]); + if (models.size() > 1) + util::myThrow("having more than one model file is not supported"); + + try + { + torch::load(classifier->getNN(), models[0]); + } catch (std::exception & e) + { + util::myThrow(fmt::format("error when loading '{}' : {}", models[0].string(), e.what())); + } } void ReadingMachine::readFromFile(std::filesystem::path path) diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 6832e3b..7d8a52b 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -168,7 +168,7 @@ int MacaonTrain::main() trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval); machine.getClassifier()->resetOptimizer(); - auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt"; + auto optimizerCheckpoint = machinePath.parent_path() / "checkpoint.optimizer"; if (std::filesystem::exists(trainInfos)) machine.getClassifier()->loadOptimizer(optimizerCheckpoint); -- GitLab