From 89fe9c355ce91cffd2b2b61e57a252f4128fd5cf Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 21 Jul 2021 16:33:16 +0200
Subject: [PATCH] Added and integred producer

---
 decoder/include/Beam.hpp           |  3 ++-
 decoder/include/Decoder.hpp        |  3 ++-
 decoder/include/Producer.hpp       | 21 +++++++++++++++++++++
 decoder/src/Beam.cpp               |  7 ++++++-
 decoder/src/Decoder.cpp            |  4 ++--
 decoder/src/MacaonDecode.cpp       | 23 ++++++++++++++++++-----
 decoder/src/Producer.cpp           | 22 ++++++++++++++++++++++
 reading_machine/include/Config.hpp |  5 +++++
 reading_machine/src/Action.cpp     |  2 +-
 reading_machine/src/BaseConfig.cpp |  3 ---
 reading_machine/src/Config.cpp     | 21 +++++++++++++++++++++
 reading_machine/src/Strategy.cpp   |  2 ++
 trainer/src/MacaonTrain.cpp        |  5 +++--
 13 files changed, 105 insertions(+), 16 deletions(-)
 create mode 100644 decoder/include/Producer.hpp
 create mode 100644 decoder/src/Producer.cpp

diff --git a/decoder/include/Beam.hpp b/decoder/include/Beam.hpp
index 1dd4018..e38d38f 100644
--- a/decoder/include/Beam.hpp
+++ b/decoder/include/Beam.hpp
@@ -5,6 +5,7 @@
 #include <string>
 #include "BaseConfig.hpp"
 #include "ReadingMachine.hpp"
+#include "Producer.hpp"
 
 class Beam
 {
@@ -40,7 +41,7 @@ class Beam
 
   Beam(std::size_t width, float threshold, BaseConfig & model, const ReadingMachine & machine);
   Element & operator[](std::size_t index);
-  void update(ReadingMachine & machine, bool debug);
+  void update(ReadingMachine & machine, bool debug, std::optional<Producer> & producer);
   bool isEnded() const;
 };
 
diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp
index fe8c870..0156757 100644
--- a/decoder/include/Decoder.hpp
+++ b/decoder/include/Decoder.hpp
@@ -4,6 +4,7 @@
 #include <filesystem>
 #include "ReadingMachine.hpp"
 #include "SubConfig.hpp"
+#include "Producer.hpp"
 
 class Decoder
 {
@@ -25,7 +26,7 @@ class Decoder
   public :
 
   Decoder(ReadingMachine & machine);
-  std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement);
+  std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> producer);
   void evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted);
   std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const;
   std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const;
diff --git a/decoder/include/Producer.hpp b/decoder/include/Producer.hpp
new file mode 100644
index 0000000..ec95402
--- /dev/null
+++ b/decoder/include/Producer.hpp
@@ -0,0 +1,21 @@
+#ifndef PRODUCER__H
+#define PRODUCER__H
+
+#include <filesystem>
+#include "Config.hpp"
+
+class Producer
+{
+  private :
+
+  static constexpr int maxNb = 100;
+  int curNb = 0;
+
+  public :
+
+  Producer(std::filesystem::path path);
+
+  bool apply(Config & config);
+};
+
+#endif
diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp
index c932733..1a5e151 100644
--- a/decoder/src/Beam.cpp
+++ b/decoder/src/Beam.cpp
@@ -22,7 +22,7 @@ Beam::Element & Beam::operator[](std::size_t index)
   return elements[index];
 }
 
-void Beam::update(ReadingMachine & machine, bool debug)
+void Beam::update(ReadingMachine & machine, bool debug, std::optional<Producer> & producer)
 {
   ended = true;
   auto currentNbElements = elements.size();
@@ -37,6 +37,11 @@ void Beam::update(ReadingMachine & machine, bool debug)
 
     ended = false;
 
+    if (producer.has_value() and elements[index].config.getState() == "tokenizer")
+      elements[index].config.setRawInputStatus(not producer.value().apply(elements[index].config));
+    if (not producer.has_value())
+      elements[index].config.setRawInputStatus(true);
+
     auto & classifier = *machine.getClassifier(elements[index].config.getState());
 
     if (machine.hasSplitWordTransitionSet())
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 5394280..7b1e7b0 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -6,7 +6,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
 {
 }
 
-std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement)
+std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> producer)
 {
   constexpr int printInterval = 50;
 
@@ -20,7 +20,7 @@ std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float
   {
     while (!beam.isEnded())
     {
-      beam.update(machine, debug);
+      beam.update(machine, debug, producer);
       ++totalNbExamplesProcessed;
 
       if (printAdvancement)
diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp
index 22e715f..f964061 100644
--- a/decoder/src/MacaonDecode.cpp
+++ b/decoder/src/MacaonDecode.cpp
@@ -4,6 +4,7 @@
 #include "util.hpp"
 #include "Decoder.hpp"
 #include "Submodule.hpp"
+#include "Producer.hpp"
 
 po::options_description MacaonDecode::getOptionsDescription()
 {
@@ -16,7 +17,9 @@ po::options_description MacaonDecode::getOptionsDescription()
     ("inputTSV", po::value<std::string>(),
       "File containing the text to decode, TSV file")
     ("inputTXT", po::value<std::string>(),
-      "File containing the text to decode, raw text file");
+      "File containing the text to decode, raw text file")
+    ("inputSES", po::value<std::string>(),
+      "File containing a list of actions that will fill the input tape");
 
   po::options_description opt("Optional");
   opt.add_options()
@@ -55,7 +58,7 @@ po::variables_map MacaonDecode::checkOptions(po::options_description & od)
   try {po::notify(vm);}
   catch(std::exception& e) {util::myThrow(e.what());}
 
-  if (vm.count("inputTSV") + vm.count("inputTXT") != 1)
+  if (vm.count("inputTSV") + vm.count("inputTXT") + vm.count("inputSES") != 1)
   {
     std::stringstream ss;
     ss << od;
@@ -76,6 +79,7 @@ int MacaonDecode::main()
   auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, ""));
   auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
   auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
+  auto inputSES = variables.count("inputSES") ? variables["inputSES"].as<std::string>() : "";
   auto mcd = variables["mcd"].as<std::string>();
   bool debug = variables.count("debug") == 0 ? false : true;
   bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
@@ -117,7 +121,9 @@ int MacaonDecode::main()
     }
     else
     {
-      if (rawInputs.size())
+      if (inputSES.size())
+        configs.emplace_back(mcd, noTsv, util::utf8string(), std::vector<int>());
+      else if (rawInputs.size())
         configs.emplace_back(mcd, noTsv, rawInputs[0], std::vector<int>());
       else
         configs.emplace_back(mcd, tsv, util::utf8string(), std::vector<int>());
@@ -132,11 +138,18 @@ int MacaonDecode::main()
       std::for_each(std::execution::par, configs.begin(), configs.end(),
         [&decoder, debug, printAdvancement, beamSize, beamThreshold](BaseConfig & config)
         {
-          decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement);
+          decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>());
         });
     }
     else
-      decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement);
+    {
+      if (not inputSES.empty())
+      {
+        decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>(Producer(inputSES)));
+      }
+      else
+        decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>());
+    }
 
     for (unsigned int i = 0; i < configs.size(); i++)
       configs[i].print(stdout, i == 0);
diff --git a/decoder/src/Producer.cpp b/decoder/src/Producer.cpp
new file mode 100644
index 0000000..8d564dd
--- /dev/null
+++ b/decoder/src/Producer.cpp
@@ -0,0 +1,22 @@
+#include "Producer.hpp"
+
+Producer::Producer(std::filesystem::path)
+{
+}
+
+bool Producer::apply(Config & config)
+{
+  if (util::choiceWithProbability(0.05))
+  {
+    config.rawInputAdd(".");
+    config.rawInputAdd(" ");
+  }
+  else if (util::choiceWithProbability(0.8))
+    config.rawInputAdd(fmt::format("{}", (char) ('a'+rand()%26)));
+  else
+    config.rawInputAdd(" ");
+
+  curNb++;  
+  return curNb < maxNb;
+}
+
diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp
index ed0a29d..db8eb10 100644
--- a/reading_machine/include/Config.hpp
+++ b/reading_machine/include/Config.hpp
@@ -49,6 +49,7 @@ class Config
   protected :
 
   Utf8String rawInput;
+  bool rawInputIsComplete;
   std::size_t wordIndex{0};
   std::size_t characterIndex{0};
   std::size_t currentSentenceStartRawInput{0};
@@ -120,6 +121,10 @@ class Config
   util::String & getFirstEmpty(const std::string & colName, int lineIndex);
   bool hasCharacter(int letterIndex) const;
   const util::utf8char & getLetter(int letterIndex) const;
+  bool getRawInputStatus() const;
+  void setRawInputStatus(bool status);
+  void rawInputPop();
+  void rawInputAdd(util::utf8char letter);
   void addToHistory(const std::string & transition);
   void addToStack(std::size_t index);
   void popStack();
diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp
index d978112..72e33ec 100644
--- a/reading_machine/src/Action.cpp
+++ b/reading_machine/src/Action.cpp
@@ -444,7 +444,7 @@ Action Action::endWord()
     config.setCurrentWordId(config.getCurrentWordId()+1);
     addHypothesisRelative(Config::idColName, Config::Object::Buffer, 0, std::to_string(config.getCurrentWordId())).apply(config, a);
     
-    if (!config.rawInputOnlySeparatorsLeft() and !config.has(0,config.getWordIndex()+1,0))
+    if (!(config.rawInputOnlySeparatorsLeft() and config.getRawInputStatus()) and !config.has(0,config.getWordIndex()+1,0))
       config.addLines(1);
   };
 
diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp
index 61488eb..086f7f0 100644
--- a/reading_machine/src/BaseConfig.cpp
+++ b/reading_machine/src/BaseConfig.cpp
@@ -156,9 +156,6 @@ BaseConfig::BaseConfig(const BaseConfig & other) : Config(other), colIndex2Name(
 
 BaseConfig::BaseConfig(std::string mcd, const std::vector<std::vector<std::string>> & sentences, const util::utf8string & rawInput, const std::vector<int> & sentencesIndexes)
 {
-  if (sentences.empty() and rawInput.empty())
-    util::myThrow("sentences and rawInput can't be both empty");
-
   createColumns(mcd);
 
   if (not rawInput.empty())
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 69ae400..22ca045 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -14,6 +14,7 @@ Config::Config(const Config & other)
   this->strategy.reset(other.strategy ? new Strategy(*other.strategy) : nullptr);
 
   this->rawInput = other.rawInput;
+  this->rawInputIsComplete = other.rawInputIsComplete;
   this->wordIndex = other.wordIndex;
   this->characterIndex = other.characterIndex;
   this->state = other.state;
@@ -396,6 +397,26 @@ const util::utf8char & Config::getLetter(int letterIndex) const
   return rawInput[letterIndex];
 }
 
+bool Config::getRawInputStatus() const
+{
+  return rawInputIsComplete;
+}
+
+void Config::setRawInputStatus(bool status)
+{
+  rawInputIsComplete = status;
+}
+
+void Config::rawInputPop()
+{
+  rawInput.pop_back();
+}
+
+void Config::rawInputAdd(util::utf8char letter)
+{
+  rawInput.push_back(letter);
+}
+
 bool Config::isMultiword(std::size_t lineIndex) const
 {
   return hasColIndex(idColName) && std::string(getConst(idColName, lineIndex, 0)).find('-') != std::string::npos;
diff --git a/reading_machine/src/Strategy.cpp b/reading_machine/src/Strategy.cpp
index b22b9e6..d23440e 100644
--- a/reading_machine/src/Strategy.cpp
+++ b/reading_machine/src/Strategy.cpp
@@ -116,7 +116,9 @@ bool Strategy::Block::isFinished(const Config & c, const Movement & movement)
     if (condition == EndCondition::CannotMove)
     {
       if (c.canMoveWordIndex(movement.second))
+      {
         return false;
+      }
     }
 
   return true;
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index d760d1f..08832ba 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -4,6 +4,7 @@
 #include "util.hpp"
 #include "NeuralNetwork.hpp"
 #include "WordEmbeddings.hpp"
+#include "Producer.hpp"
 
 namespace po = boost::program_options;
 
@@ -332,14 +333,14 @@ int MacaonTrain::main()
         std::for_each(std::execution::par, devConfigs.begin(), devConfigs.end(),
           [&decoder, debug, printAdvancement](BaseConfig & devConfig)
           {
-            decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
+            decoder.decode(devConfig, 1, 0.0, debug, printAdvancement, std::optional<Producer>());
           });
         NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice());
         machine.to(NeuralNetworkImpl::getDevice());
       }
       else
       {
-        decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement);
+        decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement, std::optional<Producer>());
       }
 
       std::vector<const Config *> devConfigsPtrs;
-- 
GitLab