From 93b2c58cf28a7095db669f5c51d015e82760b8e9 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Thu, 30 Apr 2020 12:52:12 +0200 Subject: [PATCH] Corrected bug where embeddings were not loaded when training resumed --- reading_machine/include/ReadingMachine.hpp | 1 + reading_machine/src/ReadingMachine.cpp | 10 +++++++--- trainer/src/MacaonTrain.cpp | 5 +++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index 8a51466..34f1745 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -53,6 +53,7 @@ class ReadingMachine void saveLast() const; void saveDicts() const; bool dictsAreNew() const; + void loadLastSaved(); }; #endif diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 70d8123..ec05ad5 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -5,10 +5,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path) { readFromFile(path); - auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, "")); auto savedDicts = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::defaultDictFilename, "")); - if (!lastSavedModel.empty()) - torch::load(classifier->getNN(), lastSavedModel[0]); for (auto path : savedDicts) this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Open}); @@ -207,3 +204,10 @@ bool ReadingMachine::dictsAreNew() const return _dictsAreNew; } +void ReadingMachine::loadLastSaved() +{ + auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, "")); + if (!lastSavedModel.empty()) + torch::load(classifier->getNN(), lastSavedModel[0]); +} + diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 3bb18a2..86acd6b 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -106,8 +106,6 @@ int MacaonTrain::main() ReadingMachine machine(machinePath.string()); - fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters())); - BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile); BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); @@ -136,8 +134,11 @@ int MacaonTrain::main() for (auto & it : machine.getDicts()) maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size()); machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize); + machine.loadLastSaved(); machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device); + fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters())); + float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max(); auto trainInfos = machinePath.parent_path() / "train.info"; -- GitLab