From 29907cb503c988bd8b6867cc239d6d58e0ac1515 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 29 Apr 2020 23:25:11 +0200 Subject: [PATCH] Corrected a bug where dict was modified if training was resumed --- reading_machine/include/ReadingMachine.hpp | 2 ++ reading_machine/src/ReadingMachine.cpp | 8 ++++++ trainer/src/MacaonTrain.cpp | 29 +++++++++++++--------- 3 files changed, 27 insertions(+), 12 deletions(-) diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index 9eb09d0..8a51466 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -25,6 +25,7 @@ class ReadingMachine std::unique_ptr<Strategy> strategy; std::map<std::string, Dict> dicts; std::set<std::string> predicted; + bool _dictsAreNew{false}; std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr}; @@ -51,6 +52,7 @@ class ReadingMachine void saveBest() const; void saveLast() const; void saveDicts() const; + bool dictsAreNew() const; }; #endif diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index bbb05d4..70d8123 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -14,7 +14,10 @@ ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path) this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Open}); if (dicts.count(defaultDictName) == 0) + { + _dictsAreNew = true; dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open)); + } } ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts) @@ -199,3 +202,8 @@ std::map<std::string, Dict> & ReadingMachine::getDicts() return dicts; } +bool ReadingMachine::dictsAreNew() const +{ + return _dictsAreNew; +} + diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 601b176..3bb18a2 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -114,24 +114,29 @@ int MacaonTrain::main() Trainer trainer(machine, batchSize); Decoder decoder(machine); - trainer.fillDicts(goldConfig); - std::size_t maxDictSize = 0; - for (auto & it : machine.getDicts()) + if (machine.dictsAreNew()) { - std::size_t originalSize = it.second.size(); - for (;;) + trainer.fillDicts(goldConfig); + for (auto & it : machine.getDicts()) { - std::size_t lastSize = it.second.size(); - it.second.removeRareElements(); - float decrease = 100.0*(originalSize-it.second.size())/originalSize; - if (decrease >= rarityThreshold or lastSize == it.second.size()) - break; + std::size_t originalSize = it.second.size(); + for (;;) + { + std::size_t lastSize = it.second.size(); + it.second.removeRareElements(); + float decrease = 100.0*(originalSize-it.second.size())/originalSize; + if (decrease >= rarityThreshold or lastSize == it.second.size()) + break; + } } - maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size()); + machine.saveDicts(); } + + std::size_t maxDictSize = 0; + for (auto & it : machine.getDicts()) + maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size()); machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize); machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device); - machine.saveDicts(); float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max(); -- GitLab