diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 3a2caefd42fdd6601ecc5051aa72fa023cf47898..78e62048ab12864c1d8fb8b7afda396bee9ed619 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -24,7 +24,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool if (debug) config.printForDebug(stderr); - config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); + if (machine.hasSplitWordTransitionSet()) + config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); auto dictState = machine.getDict(config.getState()).getState(); auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back(); diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index dd0098d7d9115602957f61e34762f2c92e8191fc..dcece9ae4680a93e316de3919c88884c1fb84f9b 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -26,7 +26,7 @@ class ReadingMachine std::map<std::string, Dict> dicts; std::set<std::string> predicted; - std::unique_ptr<TransitionSet> splitWordTransitionSet; + std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr}; private : @@ -39,6 +39,7 @@ class ReadingMachine ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts); TransitionSet & getTransitionSet(); TransitionSet & getSplitWordTransitionSet(); + bool hasSplitWordTransitionSet() const; Strategy & getStrategy(); Dict & getDict(const std::string & state); std::map<std::string, Dict> & getDicts(); diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 9d4a854c19ee9f48c35e52e975214719e888c58d..94de1cbf43e15167fddad47b9ed6b043a0f08651 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -69,9 +69,10 @@ void ReadingMachine::readFromFile(std::filesystem::path path) --curLine; - util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine++], [this,path](auto sm) + util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine], [this,path,&curLine](auto sm) { this->splitWordTransitionSet.reset(new TransitionSet(path.parent_path() / sm.str(1))); + curLine++; }); if (!util::doIfNameMatch(std::regex("Predictions : (.+)"), lines[curLine++], [this](auto sm) @@ -95,6 +96,11 @@ TransitionSet & ReadingMachine::getTransitionSet() return classifier->getTransitionSet(); } +bool ReadingMachine::hasSplitWordTransitionSet() const +{ + return splitWordTransitionSet.get() != nullptr; +} + TransitionSet & ReadingMachine::getSplitWordTransitionSet() { return *splitWordTransitionSet; diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 6459bb332f52a4693b25f43329d67ca58e0b9f75..d69d74a4ebbbb538520aa77ee111edf3b3fd5368 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -43,7 +43,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch: if (debug) config.printForDebug(stderr); - config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); + if (machine.hasSplitWordTransitionSet()) + config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); auto * transition = machine.getTransitionSet().getBestAppliableTransition(config); if (!transition)