From 0daae795e17c19af16fe083778e22a0b1ef9ddc2 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 14 Apr 2020 15:03:23 +0200 Subject: [PATCH] do not close dict during example extraction --- decoder/src/Decoder.cpp | 1 + reading_machine/include/ReadingMachine.hpp | 2 ++ reading_machine/src/ReadingMachine.cpp | 13 +++++++++++-- trainer/src/MacaonTrain.cpp | 1 - trainer/src/Trainer.cpp | 4 ++++ 5 files changed, 18 insertions(+), 3 deletions(-) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 1c16521..9b6b3a6 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -9,6 +9,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool { torch::AutoGradMode useGrad(false); machine.trainMode(false); + machine.setDictsState(Dict::State::Closed); machine.getStrategy().reset(); config.addPredicted(machine.getPredicted()); diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index dcece9a..9eb09d0 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -47,8 +47,10 @@ class ReadingMachine bool isPredicted(const std::string & columnName) const; const std::set<std::string> & getPredicted() const; void trainMode(bool isTrainMode); + void setDictsState(Dict::State state); void saveBest() const; void saveLast() const; + void saveDicts() const; }; #endif diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 9bd3e2f..38f79c8 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -134,7 +134,7 @@ Classifier * ReadingMachine::getClassifier() return classifier.get(); } -void ReadingMachine::save(const std::string & modelNameTemplate) const +void ReadingMachine::saveDicts() const { for (auto & it : dicts) { @@ -147,6 +147,11 @@ void ReadingMachine::save(const std::string & modelNameTemplate) const std::fclose(file); } +} + +void ReadingMachine::save(const std::string & modelNameTemplate) const +{ + saveDicts(); auto pathToClassifier = path.parent_path() / fmt::format(modelNameTemplate, classifier->getName()); torch::save(classifier->getNN(), pathToClassifier); @@ -175,8 +180,12 @@ const std::set<std::string> & ReadingMachine::getPredicted() const void ReadingMachine::trainMode(bool isTrainMode) { classifier->getNN()->train(isTrainMode); +} + +void ReadingMachine::setDictsState(Dict::State state) +{ for (auto & it : dicts) - it.second.setState(isTrainMode ? Dict::State::Open : Dict::State::Closed); + it.second.setState(state); } std::map<std::string, Dict> & ReadingMachine::getDicts() diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 280643a..08f0684 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -124,7 +124,6 @@ int MacaonTrain::main() fillDicts(machine, goldConfig); - Trainer trainer(machine, batchSize); Decoder decoder(machine); diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 5d3b2ec..efe4c2b 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -44,6 +44,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p { torch::AutoGradMode useGrad(false); machine.trainMode(false); + machine.setDictsState(Dict::State::Open); int maxNbExamplesPerFile = 250000; int currentExampleIndex = 0; @@ -163,6 +164,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str())); std::fclose(f); + machine.saveDicts(); + fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(currentExampleIndex)); } @@ -176,6 +179,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance torch::AutoGradMode useGrad(train); machine.trainMode(train); + machine.setDictsState(Dict::State::Closed); auto lossFct = torch::nn::CrossEntropyLoss(); -- GitLab