From e45d45e60d81ad9f55df4a837b8f294779f37f44 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 1 Apr 2020 22:27:37 +0200 Subject: [PATCH] Checkpoints are created after each training epoch and it is possible to resume a training by training again on the same directory --- reading_machine/include/ReadingMachine.hpp | 5 ++- reading_machine/src/ReadingMachine.cpp | 27 ++++++++++++--- trainer/include/Trainer.hpp | 2 ++ trainer/src/MacaonTrain.cpp | 38 ++++++++++++++++++++-- trainer/src/Trainer.cpp | 10 ++++++ 5 files changed, 74 insertions(+), 8 deletions(-) diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index 4ce25aa..cc56ec2 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -13,6 +13,7 @@ class ReadingMachine static inline const std::string defaultMachineFilename = "machine.rm"; static inline const std::string defaultModelFilename = "{}.pt"; + static inline const std::string lastModelFilename = "{}.last"; static inline const std::string defaultDictFilename = "{}.dict"; static inline const std::string defaultDictName = "_default_"; @@ -28,6 +29,7 @@ class ReadingMachine private : void readFromFile(std::filesystem::path path); + void save(const std::string & modelNameTemplate) const; public : @@ -38,10 +40,11 @@ class ReadingMachine Dict & getDict(const std::string & state); std::map<std::string, Dict> & getDicts(); Classifier * getClassifier(); - void save() const; bool isPredicted(const std::string & columnName) const; const std::set<std::string> & getPredicted() const; void trainMode(bool isTrainMode); + void saveBest() const; + void saveLast() const; }; #endif diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index a74491b..50ce655 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -3,9 +3,18 @@ ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path) { - dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open)); - readFromFile(path); + + auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, "")); + auto savedDicts = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::defaultDictFilename, "")); + if (!lastSavedModel.empty()) + torch::load(classifier->getNN(), lastSavedModel[0]); + + for (auto path : savedDicts) + this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Open}); + + if (dicts.count(defaultDictName) == 0) + dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open)); } ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts) @@ -98,7 +107,7 @@ Classifier * ReadingMachine::getClassifier() return classifier.get(); } -void ReadingMachine::save() const +void ReadingMachine::save(const std::string & modelNameTemplate) const { for (auto & it : dicts) { @@ -112,10 +121,20 @@ void ReadingMachine::save() const std::fclose(file); } - auto pathToClassifier = path.parent_path() / fmt::format(defaultModelFilename, classifier->getName()); + auto pathToClassifier = path.parent_path() / fmt::format(modelNameTemplate, classifier->getName()); torch::save(classifier->getNN(), pathToClassifier); } +void ReadingMachine::saveBest() const +{ + save(defaultModelFilename); +} + +void ReadingMachine::saveLast() const +{ + save(lastModelFilename); +} + bool ReadingMachine::isPredicted(const std::string & columnName) const { return predicted.count(columnName); diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 259a150..b5a548c 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -34,6 +34,8 @@ class Trainer void createDevDataset(SubConfig & goldConfig, bool debug); float epoch(bool printAdvancement); float evalOnDev(bool printAdvancement); + void loadOptimizer(std::filesystem::path path); + void saveOptimizer(std::filesystem::path path); }; #endif diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 5f5db36..7b8e60f 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -117,7 +117,33 @@ int MacaonTrain::main() float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max(); - for (int i = 0; i < nbEpoch; i++) + auto trainInfos = machinePath.parent_path() / "train.info"; + + int currentEpoch = 0; + + if (std::filesystem::exists(trainInfos)) + { + std::FILE * f = std::fopen(trainInfos.c_str(), "r"); + char buffer[1024]; + while (!std::feof(f)) + { + if (buffer != std::fgets(buffer, 1024, f)) + break; + float devScoreMean = std::stof(util::split(buffer, '\t').back()); + if (computeDevScore and devScoreMean > bestDevScore) + bestDevScore = devScoreMean; + if (!computeDevScore and devScoreMean < bestDevScore) + bestDevScore = devScoreMean; + currentEpoch++; + } + std::fclose(f); + } + + auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt"; + if (std::filesystem::exists(trainInfos)) + trainer.loadOptimizer(optimizerCheckpoint); + + for (; currentEpoch < nbEpoch; currentEpoch++) { float loss = trainer.epoch(printAdvancement); machine.getStrategy().reset(); @@ -157,11 +183,17 @@ int MacaonTrain::main() if (saved) { bestDevScore = devScoreMean; - machine.save(); + machine.saveBest(); } + machine.saveLast(); + trainer.saveOptimizer(optimizerCheckpoint); if (!debug) fmt::print(stderr, "\r{:80}\r", ""); - fmt::print(stderr, "[{}] Epoch {:^5} loss = {:6.1f} dev = {} {:5}\n", util::getTime(), fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : ""); + std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.1f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : ""); + fmt::print(stderr, "{}\n", iterStr); + std::FILE * f = std::fopen(trainInfos.c_str(), "a"); + fmt::print(f, "{}\t{}\n", iterStr, devScoreMean); + std::fclose(f); } } diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 501af8e..6ebf18e 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -163,3 +163,13 @@ float Trainer::evalOnDev(bool printAdvancement) return processDataset(devDataLoader, false, printAdvancement); } +void Trainer::loadOptimizer(std::filesystem::path path) +{ + torch::load(*optimizer, path); +} + +void Trainer::saveOptimizer(std::filesystem::path path) +{ + torch::save(*optimizer, path); +} + -- GitLab