diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index ec05ad56819e8c8be1de133c86bfbd0622c7e7eb..2900e4dd4ff662775352a88f7add59d54c6dd8b4 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -30,7 +30,16 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file classifier->getNN()->registerEmbeddings(maxDictSize); classifier->getNN()->to(NeuralNetworkImpl::device); - torch::load(classifier->getNN(), models[0]); + if (models.size() > 1) + util::myThrow("having more than one model file is not supported"); + + try + { + torch::load(classifier->getNN(), models[0]); + } catch (std::exception & e) + { + util::myThrow(fmt::format("error when loading '{}' : {}", models[0].string(), e.what())); + } } void ReadingMachine::readFromFile(std::filesystem::path path) diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 6832e3b6c33f310b33d915e758200bbdcc623bcc..7d8a52baafbd2a8858317822e6b2bec755aae9ef 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -168,7 +168,7 @@ int MacaonTrain::main() trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval); machine.getClassifier()->resetOptimizer(); - auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt"; + auto optimizerCheckpoint = machinePath.parent_path() / "checkpoint.optimizer"; if (std::filesystem::exists(trainInfos)) machine.getClassifier()->loadOptimizer(optimizerCheckpoint);