From b9a3230cafb0b516a85c06e53b87a2d161f83c0f Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sun, 5 Apr 2020 19:05:20 +0200 Subject: [PATCH] Made splitwords transition set optional in definition of reading machine --- decoder/src/Decoder.cpp | 3 ++- reading_machine/include/ReadingMachine.hpp | 3 ++- reading_machine/src/ReadingMachine.cpp | 8 +++++++- trainer/src/Trainer.cpp | 3 ++- 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 3a2caef..78e6204 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -24,7 +24,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool if (debug) config.printForDebug(stderr); - config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); + if (machine.hasSplitWordTransitionSet()) + config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); auto dictState = machine.getDict(config.getState()).getState(); auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back(); diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index dd0098d..dcece9a 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -26,7 +26,7 @@ class ReadingMachine std::map<std::string, Dict> dicts; std::set<std::string> predicted; - std::unique_ptr<TransitionSet> splitWordTransitionSet; + std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr}; private : @@ -39,6 +39,7 @@ class ReadingMachine ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts); TransitionSet & getTransitionSet(); TransitionSet & getSplitWordTransitionSet(); + bool hasSplitWordTransitionSet() const; Strategy & getStrategy(); Dict & getDict(const std::string & state); std::map<std::string, Dict> & getDicts(); diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 9d4a854..94de1cb 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -69,9 +69,10 @@ void ReadingMachine::readFromFile(std::filesystem::path path) --curLine; - util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine++], [this,path](auto sm) + util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine], [this,path,&curLine](auto sm) { this->splitWordTransitionSet.reset(new TransitionSet(path.parent_path() / sm.str(1))); + curLine++; }); if (!util::doIfNameMatch(std::regex("Predictions : (.+)"), lines[curLine++], [this](auto sm) @@ -95,6 +96,11 @@ TransitionSet & ReadingMachine::getTransitionSet() return classifier->getTransitionSet(); } +bool ReadingMachine::hasSplitWordTransitionSet() const +{ + return splitWordTransitionSet.get() != nullptr; +} + TransitionSet & ReadingMachine::getSplitWordTransitionSet() { return *splitWordTransitionSet; diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 6459bb3..d69d74a 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -43,7 +43,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch: if (debug) config.printForDebug(stderr); - config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); + if (machine.hasSplitWordTransitionSet()) + config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); auto * transition = machine.getTransitionSet().getBestAppliableTransition(config); if (!transition) -- GitLab