Skip to content
Snippets Groups Projects
Commit e45d45e6 authored by Franck Dary's avatar Franck Dary
Browse files

Checkpoints are created after each training epoch and it is possible to resume...

Checkpoints are created after each training epoch and it is possible to resume a training by training again on the same directory
parent 1cf6cf2f
No related branches found
No related tags found
No related merge requests found
......@@ -13,6 +13,7 @@ class ReadingMachine
static inline const std::string defaultMachineFilename = "machine.rm";
static inline const std::string defaultModelFilename = "{}.pt";
static inline const std::string lastModelFilename = "{}.last";
static inline const std::string defaultDictFilename = "{}.dict";
static inline const std::string defaultDictName = "_default_";
......@@ -28,6 +29,7 @@ class ReadingMachine
private :
void readFromFile(std::filesystem::path path);
void save(const std::string & modelNameTemplate) const;
public :
......@@ -38,10 +40,11 @@ class ReadingMachine
Dict & getDict(const std::string & state);
std::map<std::string, Dict> & getDicts();
Classifier * getClassifier();
void save() const;
bool isPredicted(const std::string & columnName) const;
const std::set<std::string> & getPredicted() const;
void trainMode(bool isTrainMode);
void saveBest() const;
void saveLast() const;
};
#endif
......@@ -3,9 +3,18 @@
/// Build a reading machine from its description file, then — to support
/// resuming an interrupted training — reload any "*.last" model checkpoint
/// and any saved "*.dict" files found next to the machine file.
/// @param path filesystem path to the machine description file.
ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
{
  // Ensure an open default dict exists before parsing the machine file.
  dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open));
  readFromFile(path);

  // Look for checkpoint artifacts written by saveLast()/save() in the
  // machine's directory.
  auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, ""));
  auto savedDicts = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::defaultDictFilename, ""));

  if (!lastSavedModel.empty())
    torch::load(classifier->getNN(), lastSavedModel[0]);

  // const& : avoid copying every std::filesystem::path, and do not shadow
  // the constructor parameter 'path' (the original loop did both).
  for (const auto & dictPath : savedDicts)
    this->dicts.emplace(dictPath.stem().string(), Dict{dictPath.c_str(), Dict::State::Open});

  // NOTE(review): the default dict was already emplaced unconditionally above,
  // so this guard can never fire — and a saved "_default_.dict" is silently
  // ignored by the emplace in the loop. Confirm whether the initial emplace
  // should instead happen only here, after saved dicts are loaded.
  if (dicts.count(defaultDictName) == 0)
    dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open));
}
ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts)
......@@ -98,7 +107,7 @@ Classifier * ReadingMachine::getClassifier()
return classifier.get();
}
void ReadingMachine::save() const
void ReadingMachine::save(const std::string & modelNameTemplate) const
{
for (auto & it : dicts)
{
......@@ -112,10 +121,20 @@ void ReadingMachine::save() const
std::fclose(file);
}
auto pathToClassifier = path.parent_path() / fmt::format(defaultModelFilename, classifier->getName());
auto pathToClassifier = path.parent_path() / fmt::format(modelNameTemplate, classifier->getName());
torch::save(classifier->getNN(), pathToClassifier);
}
/// Persist dicts and classifier under the canonical "best model" filename
/// template (defaultModelFilename), used when the dev score improves.
void ReadingMachine::saveBest() const
{
  save(ReadingMachine::defaultModelFilename);
}
/// Persist dicts and classifier under the "*.last" checkpoint filename
/// template; written after every epoch so a training run can be resumed.
void ReadingMachine::saveLast() const
{
  save(ReadingMachine::lastModelFilename);
}
bool ReadingMachine::isPredicted(const std::string & columnName) const
{
return predicted.count(columnName);
......
......@@ -34,6 +34,8 @@ class Trainer
void createDevDataset(SubConfig & goldConfig, bool debug);
float epoch(bool printAdvancement);
float evalOnDev(bool printAdvancement);
void loadOptimizer(std::filesystem::path path);
void saveOptimizer(std::filesystem::path path);
};
#endif
......@@ -117,7 +117,33 @@ int MacaonTrain::main()
float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
for (int i = 0; i < nbEpoch; i++)
auto trainInfos = machinePath.parent_path() / "train.info";
int currentEpoch = 0;
if (std::filesystem::exists(trainInfos))
{
std::FILE * f = std::fopen(trainInfos.c_str(), "r");
char buffer[1024];
while (!std::feof(f))
{
if (buffer != std::fgets(buffer, 1024, f))
break;
float devScoreMean = std::stof(util::split(buffer, '\t').back());
if (computeDevScore and devScoreMean > bestDevScore)
bestDevScore = devScoreMean;
if (!computeDevScore and devScoreMean < bestDevScore)
bestDevScore = devScoreMean;
currentEpoch++;
}
std::fclose(f);
}
auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt";
if (std::filesystem::exists(trainInfos))
trainer.loadOptimizer(optimizerCheckpoint);
for (; currentEpoch < nbEpoch; currentEpoch++)
{
float loss = trainer.epoch(printAdvancement);
machine.getStrategy().reset();
......@@ -157,11 +183,17 @@ int MacaonTrain::main()
if (saved)
{
bestDevScore = devScoreMean;
machine.save();
machine.saveBest();
}
machine.saveLast();
trainer.saveOptimizer(optimizerCheckpoint);
if (!debug)
fmt::print(stderr, "\r{:80}\r", "");
fmt::print(stderr, "[{}] Epoch {:^5} loss = {:6.1f} dev = {} {:5}\n", util::getTime(), fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.1f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
fmt::print(stderr, "{}\n", iterStr);
std::FILE * f = std::fopen(trainInfos.c_str(), "a");
fmt::print(f, "{}\t{}\n", iterStr, devScoreMean);
std::fclose(f);
}
}
......
......@@ -163,3 +163,13 @@ float Trainer::evalOnDev(bool printAdvancement)
return processDataset(devDataLoader, false, printAdvancement);
}
/// Restore the optimizer state from a checkpoint file previously written by
/// saveOptimizer(), so a resumed training continues with the same optimizer
/// internals rather than a fresh one.
/// @param path checkpoint file (e.g. "optimizer.pt").
/// NOTE(review): torch::load will fail if the file is missing; the caller in
/// MacaonTrain::main guards this by testing the train.info file rather than
/// the optimizer checkpoint itself — confirm that is intended.
void Trainer::loadOptimizer(std::filesystem::path path)
{
  const auto checkpointFile = path.string();
  torch::load(*optimizer, checkpointFile);
}
/// Serialize the current optimizer state to a checkpoint file so training can
/// later be resumed via loadOptimizer().
/// @param path destination checkpoint file (e.g. "optimizer.pt").
void Trainer::saveOptimizer(std::filesystem::path path)
{
  const auto checkpointFile = path.string();
  torch::save(*optimizer, checkpointFile);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment