diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 1c1652131ea095506eaa99fedd042ef0c58e2b5a..9b6b3a67033013989b31b560dab167de3fcc08eb 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -9,6 +9,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
 {
   torch::AutoGradMode useGrad(false);
   machine.trainMode(false);
+  machine.setDictsState(Dict::State::Closed);
   machine.getStrategy().reset();

   config.addPredicted(machine.getPredicted());
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index dcece9ae4680a93e316de3919c88884c1fb84f9b..9eb09d038a853625dcbb0b649f02556a06eea94c 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -47,8 +47,10 @@ class ReadingMachine
   bool isPredicted(const std::string & columnName) const;
   const std::set<std::string> & getPredicted() const;
   void trainMode(bool isTrainMode);
+  void setDictsState(Dict::State state);
   void saveBest() const;
   void saveLast() const;
+  void saveDicts() const;
 };

 #endif
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 9bd3e2f22c5b779f040d696f01eeacdb7fcc71d3..38f79c84c0a1e14e7adffe168486d47d1674a944 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -134,7 +134,7 @@ Classifier * ReadingMachine::getClassifier()
   return classifier.get();
 }

-void ReadingMachine::save(const std::string & modelNameTemplate) const
+void ReadingMachine::saveDicts() const
 {
   for (auto & it : dicts)
   {
@@ -147,6 +147,11 @@ void ReadingMachine::save(const std::string & modelNameTemplate) const
     std::fclose(file);
   }
+}
+
+void ReadingMachine::save(const std::string & modelNameTemplate) const
+{
+  saveDicts();

   auto pathToClassifier = path.parent_path() / fmt::format(modelNameTemplate, classifier->getName());
   torch::save(classifier->getNN(), pathToClassifier);
@@ -175,8 +180,12 @@ const std::set<std::string> & ReadingMachine::getPredicted() const
 void ReadingMachine::trainMode(bool isTrainMode)
 {
   classifier->getNN()->train(isTrainMode);
+}
+
+void ReadingMachine::setDictsState(Dict::State state)
+{
   for (auto & it : dicts)
-    it.second.setState(isTrainMode ? Dict::State::Open : Dict::State::Closed);
+    it.second.setState(state);
 }

 std::map<std::string, Dict> & ReadingMachine::getDicts()
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 280643a47afcc19d2930c1531022d144c96df4cc..08f06843b6ea8f79b5e8cf302a42107bb186a95e 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -124,7 +124,6 @@ int MacaonTrain::main()

   fillDicts(machine, goldConfig);
-
   Trainer trainer(machine, batchSize);
   Decoder decoder(machine);
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 5d3b2ecd4a442c4276bb504da7c5adac96a31133..efe4c2b4a20abd3036b33bfb65a0763fcdba4e6d 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -44,6 +44,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
 {
   torch::AutoGradMode useGrad(false);
   machine.trainMode(false);
+  machine.setDictsState(Dict::State::Open);

   int maxNbExamplesPerFile = 250000;
   int currentExampleIndex = 0;
@@ -163,6 +164,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
     util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
   std::fclose(f);

+  machine.saveDicts();
+
   fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(currentExampleIndex));
 }

@@ -176,6 +179,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
   torch::AutoGradMode useGrad(train);
   machine.trainMode(train);
+  machine.setDictsState(Dict::State::Closed);

   auto lossFct = torch::nn::CrossEntropyLoss();
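
Not part of the patch: a minimal sketch of the dict life cycle implied by the hunks above, using only the calls the diff introduces (trainMode, setDictsState, saveDicts). The function name, the include path, and the meaning of Dict::State::Closed as "no new entries" are assumptions for illustration.

    #include "ReadingMachine.hpp"  // assumed include path; brings in Dict via the machine header

    // Illustrative only: the call order suggested by Trainer.cpp and Decoder.cpp above.
    void dictLifecycleSketch(ReadingMachine & machine, bool train)
    {
      // Example extraction: dicts are open so new entries can be added while examples are built.
      machine.trainMode(false);
      machine.setDictsState(Dict::State::Open);
      // ... extract and write training examples ...
      machine.saveDicts();  // persist the dicts filled during extraction

      // Training or decoding afterwards: dicts stay closed (presumably no new entries are added).
      machine.trainMode(train);  // true for a training pass, false for decoding
      machine.setDictsState(Dict::State::Closed);
    }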