From b9a3230cafb0b516a85c06e53b87a2d161f83c0f Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 5 Apr 2020 19:05:20 +0200
Subject: [PATCH] Made splitwords transition set optional in definition of
 reading machine

---
 decoder/src/Decoder.cpp                    | 3 ++-
 reading_machine/include/ReadingMachine.hpp | 3 ++-
 reading_machine/src/ReadingMachine.cpp     | 8 +++++++-
 trainer/src/Trainer.cpp                    | 3 ++-
 4 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 3a2caef..78e6204 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -24,7 +24,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
     if (debug)
       config.printForDebug(stderr);
 
-    config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
+    if (machine.hasSplitWordTransitionSet())
+      config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
 
     auto dictState = machine.getDict(config.getState()).getState();
     auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back();
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index dd0098d..dcece9a 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -26,7 +26,7 @@ class ReadingMachine
   std::map<std::string, Dict> dicts;
   std::set<std::string> predicted;
 
-  std::unique_ptr<TransitionSet> splitWordTransitionSet;
+  std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr};
 
   private :
 
@@ -39,6 +39,7 @@ class ReadingMachine
   ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts);
   TransitionSet & getTransitionSet();
   TransitionSet & getSplitWordTransitionSet();
+  bool hasSplitWordTransitionSet() const;
   Strategy & getStrategy();
   Dict & getDict(const std::string & state);
   std::map<std::string, Dict> & getDicts();
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 9d4a854..94de1cb 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -69,9 +69,10 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
 
     --curLine;
 
-    util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine++], [this,path](auto sm)
+    util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine], [this,path,&curLine](auto sm)
     {
       this->splitWordTransitionSet.reset(new TransitionSet(path.parent_path() / sm.str(1)));
+      curLine++;
     });
 
     if (!util::doIfNameMatch(std::regex("Predictions : (.+)"), lines[curLine++], [this](auto sm)
@@ -95,6 +96,11 @@ TransitionSet & ReadingMachine::getTransitionSet()
   return classifier->getTransitionSet();
 }
 
+bool ReadingMachine::hasSplitWordTransitionSet() const
+{
+  return splitWordTransitionSet.get() != nullptr;
+}
+
 TransitionSet & ReadingMachine::getSplitWordTransitionSet()
 {
   return *splitWordTransitionSet;
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 6459bb3..d69d74a 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -43,7 +43,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch:
     if (debug)
       config.printForDebug(stderr);
 
-    config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
+    if (machine.hasSplitWordTransitionSet())
+      config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
 
     auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
     if (!transition)
-- 
GitLab