Skip to content
Snippets Groups Projects
Commit 93b2c58c authored by Franck Dary
Browse files

Corrected bug where embeddings were not loaded when training resumed

parent 29907cb5
No related branches found
No related tags found
No related merge requests found
...@@ -53,6 +53,7 @@ class ReadingMachine ...@@ -53,6 +53,7 @@ class ReadingMachine
void saveLast() const; void saveLast() const;
void saveDicts() const; void saveDicts() const;
bool dictsAreNew() const; bool dictsAreNew() const;
void loadLastSaved();
}; };
#endif #endif
...@@ -5,10 +5,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path) ...@@ -5,10 +5,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
{ {
readFromFile(path); readFromFile(path);
auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, ""));
auto savedDicts = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::defaultDictFilename, "")); auto savedDicts = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::defaultDictFilename, ""));
if (!lastSavedModel.empty())
torch::load(classifier->getNN(), lastSavedModel[0]);
for (auto path : savedDicts) for (auto path : savedDicts)
this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Open}); this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Open});
...@@ -207,3 +204,10 @@ bool ReadingMachine::dictsAreNew() const ...@@ -207,3 +204,10 @@ bool ReadingMachine::dictsAreNew() const
return _dictsAreNew; return _dictsAreNew;
} }
// Loads a previously saved model into the classifier's neural network, if one is found.
//
// The search is done with util::findFilesByExtension over the directory that contains
// this machine's `path`, using ReadingMachine::lastModelFilename formatted with an empty
// string as the pattern. When nothing is found, the network is left untouched.
//
// NOTE(review): presumably must be called after the network's embeddings have been
// registered, so the saved embedding weights have somewhere to load into — confirm
// against callers.
void ReadingMachine::loadLastSaved()
{
auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, ""));
// Only the first match is loaded; any further matches are ignored.
if (!lastSavedModel.empty())
torch::load(classifier->getNN(), lastSavedModel[0]);
}
...@@ -106,8 +106,6 @@ int MacaonTrain::main() ...@@ -106,8 +106,6 @@ int MacaonTrain::main()
ReadingMachine machine(machinePath.string()); ReadingMachine machine(machinePath.string());
fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters()));
BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile); BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
...@@ -136,8 +134,11 @@ int MacaonTrain::main() ...@@ -136,8 +134,11 @@ int MacaonTrain::main()
for (auto & it : machine.getDicts()) for (auto & it : machine.getDicts())
maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size()); maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size());
machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize); machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize);
machine.loadLastSaved();
machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device); machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters()));
float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max(); float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
auto trainInfos = machinePath.parent_path() / "train.info"; auto trainInfos = machinePath.parent_path() / "train.info";
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment