Skip to content
Snippets Groups Projects
Commit 47cbed2e authored by Franck Dary's avatar Franck Dary
Browse files

Changed name of optimizer checkpoints to avoid confusion with models checkpoints

parent b4c4cd4c
No related branches found
No related tags found
No related merge requests found
......@@ -30,7 +30,16 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file
classifier->getNN()->registerEmbeddings(maxDictSize);
classifier->getNN()->to(NeuralNetworkImpl::device);
torch::load(classifier->getNN(), models[0]);
if (models.size() > 1)
util::myThrow("having more than one model file is not supported");
try
{
torch::load(classifier->getNN(), models[0]);
} catch (std::exception & e)
{
util::myThrow(fmt::format("error when loading '{}' : {}", models[0].string(), e.what()));
}
}
void ReadingMachine::readFromFile(std::filesystem::path path)
......
......@@ -168,7 +168,7 @@ int MacaonTrain::main()
trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval);
machine.getClassifier()->resetOptimizer();
auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt";
auto optimizerCheckpoint = machinePath.parent_path() / "checkpoint.optimizer";
if (std::filesystem::exists(trainInfos))
machine.getClassifier()->loadOptimizer(optimizerCheckpoint);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment