From e7e29b090188941bb94c0f1251420857df504279 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 22 Jul 2021 10:53:57 +0200
Subject: [PATCH] Added output ses

---
 decoder/include/Decoder.hpp  |  2 +-
 decoder/include/Producer.hpp |  7 ++++++-
 decoder/src/Beam.cpp         |  4 ++++
 decoder/src/Decoder.cpp      |  7 ++++++-
 decoder/src/MacaonDecode.cpp | 15 +++++++++++----
 decoder/src/Producer.cpp     | 30 ++++++++++++++++++++++++++++--
 trainer/src/MacaonTrain.cpp  |  8 +++++---
 7 files changed, 61 insertions(+), 12 deletions(-)

diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp
index 0156757..4518c72 100644
--- a/decoder/include/Decoder.hpp
+++ b/decoder/include/Decoder.hpp
@@ -26,7 +26,7 @@ class Decoder
   public :
 
   Decoder(ReadingMachine & machine);
-  std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> producer);
+  std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> & producer);
   void evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted);
   std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const;
   std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const;
diff --git a/decoder/include/Producer.hpp b/decoder/include/Producer.hpp
index ec95402..a5b331a 100644
--- a/decoder/include/Producer.hpp
+++ b/decoder/include/Producer.hpp
@@ -11,11 +11,16 @@ class Producer
   static constexpr int maxNb = 100;
   int curNb = 0;
 
+  std::filesystem::path input, output;
+  std::vector<std::string> sequence;
+
   public :
 
-  Producer(std::filesystem::path path);
+  Producer(std::filesystem::path input, std::filesystem::path output);
 
   bool apply(Config & config);
+  void addConfigToSequence(const Config & config);
+  void writeOutputFile() const;
 };
 
 #endif
diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp
index 1a5e151..90500df 100644
--- a/decoder/src/Beam.cpp
+++ b/decoder/src/Beam.cpp
@@ -140,6 +140,10 @@ void Beam::update(ReadingMachine & machine, bool debug, std::optional<Producer>
 
     config.setState(movement.first);
     config.moveWordIndexRelaxed(movement.second);
+
+
+    if (producer.has_value())
+      producer.value().addConfigToSequence(config);
   }
 
   if (debug)
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 7b1e7b0..93493bf 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -6,7 +6,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
 {
 }
 
-std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> producer)
+std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement, std::optional<Producer> & producer)
 {
   constexpr int printInterval = 50;
 
@@ -42,6 +42,11 @@ std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float
   {
     auto eosTransition = Transition("EOS b.0");
     eosTransition.apply(baseConfig);
+    baseConfig.addToHistory(eosTransition.getName());
+
+    if (producer.has_value())
+      producer.value().addConfigToSequence(baseConfig);
+
     if (debug)
     {
       fmt::print(stderr, "Forcing EOS transition\n");
diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp
index f964061..4762257 100644
--- a/decoder/src/MacaonDecode.cpp
+++ b/decoder/src/MacaonDecode.cpp
@@ -33,6 +33,8 @@ po::options_description MacaonDecode::getOptionsDescription()
       "Size of the beam during beam search")
     ("beamThreshold", po::value<float>()->default_value(0.1),
       "Minimal probability an action must have to be considered in the beam search")
+    ("outputSES", po::value<std::string>()->default_value(""),
+      "Output file for enriched SES")
     ("help,h", "Produce this help message");
 
   desc.add(req).add(opt);
@@ -80,6 +82,7 @@ int MacaonDecode::main()
   auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
   auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
   auto inputSES = variables.count("inputSES") ? variables["inputSES"].as<std::string>() : "";
+  auto outputSES = variables.count("outputSES") ? variables["outputSES"].as<std::string>() : "";
   auto mcd = variables["mcd"].as<std::string>();
   bool debug = variables.count("debug") == 0 ? false : true;
   bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
@@ -88,6 +91,8 @@ int MacaonDecode::main()
   auto beamSize = variables["beamSize"].as<int>();
   auto beamThreshold = variables["beamThreshold"].as<float>();
 
+  auto noProducer = std::optional<Producer>();
+
   torch::globalContext().setBenchmarkCuDNN(true);
   Submodule::setReloadPretrained(reloadPretrained);
 
@@ -136,19 +141,21 @@ int MacaonDecode::main()
       NeuralNetworkImpl::setDevice(torch::kCPU);
       machine.to(NeuralNetworkImpl::getDevice());
       std::for_each(std::execution::par, configs.begin(), configs.end(),
-        [&decoder, debug, printAdvancement, beamSize, beamThreshold](BaseConfig & config)
+        [&decoder, debug, printAdvancement, beamSize, beamThreshold, &noProducer](BaseConfig & config)
         {
-          decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>());
+          decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement, noProducer);
         });
     }
     else
     {
       if (not inputSES.empty())
       {
-        decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>(Producer(inputSES)));
+        auto producer = std::optional<Producer>(Producer(inputSES, outputSES));
+        decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, producer);
+        producer.value().writeOutputFile();
       }
       else
-        decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, std::optional<Producer>());
+        decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement, noProducer);
     }
 
     for (unsigned int i = 0; i < configs.size(); i++)
diff --git a/decoder/src/Producer.cpp b/decoder/src/Producer.cpp
index 8d564dd..9906095 100644
--- a/decoder/src/Producer.cpp
+++ b/decoder/src/Producer.cpp
@@ -1,22 +1,48 @@
 #include "Producer.hpp"
 
-Producer::Producer(std::filesystem::path)
+Producer::Producer(std::filesystem::path input, std::filesystem::path output) : input(input), output(output)
 {
 }
 
+// Add one or more characters to config's rawInput.
+// Returns false if we are finished and true if we have events remaining.
 bool Producer::apply(Config & config)
 {
   if (util::choiceWithProbability(0.05))
   {
     config.rawInputAdd(".");
     config.rawInputAdd(" ");
+    sequence.push_back("<addletter \".\">");
+    sequence.push_back("<addletter \" \">");
   }
   else if (util::choiceWithProbability(0.8))
-    config.rawInputAdd(fmt::format("{}", (char) ('a'+rand()%26)));
+  {
+    auto letter = fmt::format("{}", (char) ('a'+rand()%26));
+    config.rawInputAdd(letter);
+    sequence.push_back(fmt::format("<addletter \"{}\">", letter));
+  }
   else
+  {
     config.rawInputAdd(" ");
+    sequence.push_back("<addletter \" \">");
+  }
 
   curNb++;  
   return curNb < maxNb;
 }
 
+// Adds an event in the sequence that represent the current config state.
+void Producer::addConfigToSequence(const Config & config)
+{
+  sequence.push_back(fmt::format("<action \"{}\">", config.getHistory(0)));
+}
+
+// Writes the entire sequence to the output file.
+void Producer::writeOutputFile() const
+{
+  std::FILE * outputFile = output.empty() ? stdout : std::fopen(output.c_str(), "w");
+
+  for (auto & event : sequence)
+    fmt::print(outputFile, "{}\n", event);
+}
+
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 08832ba..7229ee4 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -151,6 +151,8 @@ int MacaonTrain::main()
   std::srand(seed);
   torch::manual_seed(seed);
 
+  auto noProducer = std::optional<Producer>();
+
   auto trainStrategy = parseTrainStrategy(trainStrategyStr);
 
   torch::globalContext().setBenchmarkCuDNN(true);
@@ -331,16 +333,16 @@ int MacaonTrain::main()
         NeuralNetworkImpl::setDevice(torch::kCPU);
         machine.to(NeuralNetworkImpl::getDevice());
         std::for_each(std::execution::par, devConfigs.begin(), devConfigs.end(),
-          [&decoder, debug, printAdvancement](BaseConfig & devConfig)
+          [&decoder, debug, printAdvancement, &noProducer](BaseConfig & devConfig)
           {
-            decoder.decode(devConfig, 1, 0.0, debug, printAdvancement, std::optional<Producer>());
+            decoder.decode(devConfig, 1, 0.0, debug, printAdvancement, noProducer);
           });
         NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice());
         machine.to(NeuralNetworkImpl::getDevice());
       }
       else
       {
-        decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement, std::optional<Producer>());
+        decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement, noProducer);
       }
 
       std::vector<const Config *> devConfigsPtrs;
-- 
GitLab