From 78a246d70252dec64580e1b3e0df8511e0b7fbb2 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 13 Jan 2020 15:13:02 +0100
Subject: [PATCH] Working skeleton of oracle decoding

---
 dev/src/dev.cpp                            | 46 +++++++++++++++++-----
 reading_machine/include/Classifier.hpp     |  1 +
 reading_machine/include/ReadingMachine.hpp |  2 +
 reading_machine/include/Strategy.hpp       |  2 +
 reading_machine/include/Transition.hpp     |  1 +
 reading_machine/include/TransitionSet.hpp  |  4 +-
 reading_machine/src/Classifier.cpp         |  5 +++
 reading_machine/src/Config.cpp             |  9 +++--
 reading_machine/src/ReadingMachine.cpp     | 10 +++++
 reading_machine/src/Strategy.cpp           |  7 ++++
 reading_machine/src/Transition.cpp         |  5 +++
 reading_machine/src/TransitionSet.cpp      | 36 ++++++++++++++---
 12 files changed, 107 insertions(+), 21 deletions(-)

diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp
index d33d370..c538b98 100644
--- a/dev/src/dev.cpp
+++ b/dev/src/dev.cpp
@@ -8,23 +8,49 @@
 
 int main(int argc, char * argv[])
 {
-  /*
-  BaseConfig goldConfig(argv[3], argv[1], argv[2]);
+  if (argc != 5) // program name plus 4 arguments are demanded, although argv[4] is not read below
+  {
+    fmt::print(stderr, "needs 4 arguments.\n");
+    exit(1);
+  }
+
+  std::string machineFile = argv[1];
+  std::string mcdFile = argv[2];
+  std::string tsvFile = argv[3];
+  //std::string rawFile = argv[4];
+  std::string rawFile = ""; // the raw input path is forced to empty while the argv[4] line above stays commented out
+
+  ReadingMachine machine(machineFile);
 
+  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
   SubConfig config(goldConfig);
-  auto other = config;
 
-  while (config.moveWordIndex(1))
+  config.setState(machine.getStrategy().getInitialState()); // begin in the state the strategy reports as initial
+
+  config.printForDebug(stderr);
+
+  while (true) // left only through the break on Strategy::endMovement or through util::myThrow
   {
+    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config); // null when nothing is appliable
+    if (!transition)
+      util::myThrow("No transition appliable !");
+
+    fmt::print(stderr, "Transition : {}\n", transition->getName());
+    transition->apply(config);
+
+    auto movement = machine.getStrategy().getMovement(config, transition->getName()); // pair: first = next state, second = word index offset
+    if (movement == Strategy::endMovement) // presumably the strategy's "finished" sentinel — confirm in Strategy.hpp
+      break;
+
+    config.setState(movement.first);
+    if (!config.moveWordIndex(movement.second))
+      util::myThrow("Cannot move word index !");
+
     if (config.needsUpdate())
       config.update();
-  }
-
-  fmt::print(stderr, "ok\n");
-  std::scanf("%*c");
-  */
 
-  ReadingMachine machine(argv[1]);
+    config.printForDebug(stderr); // dump the configuration after every applied transition
+  }
 
   return 0;
 }
diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index ce61d5a..5d38ae8 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -16,6 +16,7 @@ class Classifier
   public :
 
   Classifier(const std::string & name, const std::string & topology, const std::string & tsFile);
+  TransitionSet & getTransitionSet();
 };
 
 #endif
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index dc4be73..ab5d794 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -16,6 +16,8 @@ class ReadingMachine
   public :
 
   ReadingMachine(const std::string & filename);
+  TransitionSet & getTransitionSet();
+  Strategy & getStrategy();
 };
 
 #endif
diff --git a/reading_machine/include/Strategy.hpp b/reading_machine/include/Strategy.hpp
index e1bc580..10d9628 100644
--- a/reading_machine/include/Strategy.hpp
+++ b/reading_machine/include/Strategy.hpp
@@ -21,6 +21,7 @@ class Strategy
   std::map<std::pair<std::string, std::string>, std::string> edges;
   std::map<std::string, bool> isDone;
   std::vector<std::string> defaultCycle;
+  std::string initialState{"UNDEFINED"};
 
   private :
 
@@ -31,6 +32,7 @@ class Strategy
 
   Strategy(const std::vector<std::string_view> & lines);
   std::pair<std::string, int> getMovement(const Config & c, const std::string & transition);
+  const std::string getInitialState() const;
 };
 
 #endif
diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp
index ecd90df..cbdf21f 100644
--- a/reading_machine/include/Transition.hpp
+++ b/reading_machine/include/Transition.hpp
@@ -23,6 +23,7 @@ class Transition
   void apply(Config & config);
   bool appliable(const Config & config) const;
   int getCost(const Config & config) const;
+  const std::string & getName() const;
 };
 
 #endif
diff --git a/reading_machine/include/TransitionSet.hpp b/reading_machine/include/TransitionSet.hpp
index 04c1e90..188ce70 100644
--- a/reading_machine/include/TransitionSet.hpp
+++ b/reading_machine/include/TransitionSet.hpp
@@ -11,13 +11,13 @@ class TransitionSet
   private :
 
   std::vector<Transition> transitions;
-  std::unordered_map<std::string, std::size_t> name2index;
   std::optional<std::size_t> defaultAction;
 
   public :
 
   TransitionSet(const std::string & filename);
-  std::vector<std::pair<Transition &, int>> getAppliableTransitionsCosts(const Config & c);
+  std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c);
+  Transition * getBestAppliableTransition(const Config & c);
 };
 
 #endif
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 47100c7..d446be2 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -7,3 +7,8 @@ Classifier::Classifier(const std::string & name, const std::string & topology, c
   this->nn = MLP(topology);
 }
 
+TransitionSet & Classifier::getTransitionSet() // Accessor: returns the classifier's transition set by non-const reference.
+{
+  return *transitionSet; // dereferenced without a check — presumably initialised by the constructor; confirm in Classifier.hpp
+}
+
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 1a29fff..f4d1001 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -80,7 +80,9 @@ void Config::printForDebug(FILE * dest) const
 {
   static constexpr int windowSize = 5;
   static constexpr int lettersWindowSize = 40;
-  static constexpr int maxWordLength = 10;
+  static constexpr int maxWordLength = 7;
+
+  fmt::print(dest, "\n");
 
   int firstLineToPrint = wordIndex;
   int lastLineToPrint = wordIndex;
@@ -138,7 +140,8 @@ void Config::printForDebug(FILE * dest) const
   fmt::print(dest, "{}\n", longLine);
   for (std::size_t index = characterIndex; index < util::getSize(rawInput) and index - characterIndex < lettersWindowSize; index++)
     fmt::print(dest, "{}", getLetter(index));
-  fmt::print(dest, "\n{}\n", longLine);
+  if (rawInput.size())
+    fmt::print(dest, "\n{}\n", longLine);
   fmt::print(dest, "State={}\nwordIndex={} characterIndex={}\nhistory=({})\nstack=({})\n", state, wordIndex, characterIndex, historyStr, stackStr);
   fmt::print(dest, "{}\n", longLine);
 
@@ -151,8 +154,6 @@ void Config::printForDebug(FILE * dest) const
     if (toPrint[line].back() == EOSSymbol1)
       fmt::print(dest, "\n");
   }
-
-  fmt::print(dest, "\n");
 }
 
 Config::String & Config::getLastNotEmpty(int colIndex, int lineIndex)
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index cc484eb..f2d49be 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -34,3 +34,13 @@ ReadingMachine::ReadingMachine(const std::string & filename)
   } catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", filename, e.what()));}
 }
 
+TransitionSet & ReadingMachine::getTransitionSet() // Convenience accessor: forwards to the classifier's getTransitionSet().
+{
+  return classifier->getTransitionSet(); // classifier is used unchecked — assumes the constructor created it (TODO confirm)
+}
+
+Strategy & ReadingMachine::getStrategy() // Accessor: returns the machine's strategy by non-const reference.
+{
+  return *strategy; // strategy is dereferenced unchecked — assumes the constructor created it (TODO confirm)
+}
+
diff --git a/reading_machine/src/Strategy.cpp b/reading_machine/src/Strategy.cpp
index 3b49cd4..ca045f1 100644
--- a/reading_machine/src/Strategy.cpp
+++ b/reading_machine/src/Strategy.cpp
@@ -16,6 +16,8 @@ Strategy::Strategy(const std::vector<std::string_view> & lines)
     {
       key = std::pair<std::string,std::string>(splited[0], "");
       value = splited[1];
+      if (defaultCycle.empty())
+        initialState = splited[0];
       defaultCycle.emplace_back(value);
     }
     else if (splited.size() == 3)
@@ -100,3 +102,8 @@ std::pair<std::string, int> Strategy::getMovementIncremental(const Config & c, c
   return {defaultCycle.back(), 1};
 }
 
+const std::string Strategy::getInitialState() const // Returns a copy of the initialState member.
+{
+  return initialState; // NOTE(review): const by-value return copies on every call; a const reference would need the header declaration changed in step
+}
+
diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp
index d8b9c8a..46147d4 100644
--- a/reading_machine/src/Transition.cpp
+++ b/reading_machine/src/Transition.cpp
@@ -61,3 +61,8 @@ void Transition::initWrite(std::string colName, std::string object, std::string
   };
 }
 
+const std::string & Transition::getName() const // Read-only access to this transition's name; no copy is made.
+{
+  return name; // the reference designates the member itself, so it is valid only as long as this Transition object is
+}
+
diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp
index 467ad42..a1a1f89 100644
--- a/reading_machine/src/TransitionSet.cpp
+++ b/reading_machine/src/TransitionSet.cpp
@@ -1,4 +1,5 @@
 #include "TransitionSet.hpp"
+#include <limits>
 
 TransitionSet::TransitionSet(const std::string & filename)
 {
@@ -23,14 +24,14 @@ TransitionSet::TransitionSet(const std::string & filename)
   std::fclose(file);
 }
 
-std::vector<std::pair<Transition &, int>> TransitionSet::getAppliableTransitionsCosts(const Config & c)
+std::vector<std::pair<Transition*, int>> TransitionSet::getAppliableTransitionsCosts(const Config & c)
 {
-  using Pair = std::pair<Transition &, int>;
+  using Pair = std::pair<Transition*, int>;
   std::vector<Pair> appliableTransitions;
 
-  for (auto & transition : transitions)
-    if (transition.appliable(c))
-      appliableTransitions.emplace_back(transition, transition.getCost(c));
+  for (unsigned int i = 0; i < transitions.size(); i++)
+    if (transitions[i].appliable(c))
+      appliableTransitions.emplace_back(&transitions[i], transitions[i].getCost(c));
 
   std::sort(appliableTransitions.begin(), appliableTransitions.end(), 
   [](const Pair & a, const Pair & b)
@@ -41,3 +42,28 @@ std::vector<std::pair<Transition &, int>> TransitionSet::getAppliableTransitions
   return appliableTransitions;
 }
 
+// Returns the lowest-cost transition appliable to c (the first zero-cost one wins), or nullptr if none qualifies.
+Transition * TransitionSet::getBestAppliableTransition(const Config & c)
+{
+  Transition * best = nullptr;
+  int lowestCost = std::numeric_limits<int>::max();
+  for (auto & candidate : transitions)
+  {
+    if (!candidate.appliable(c))
+      continue;
+
+    const int candidateCost = candidate.getCost(c);
+    // A zero cost ends the scan at once: the first such transition is returned.
+    if (candidateCost == 0)
+      return &candidate;
+
+    // Strict comparison: ties keep the earlier transition, and a cost equal to
+    // the int maximum never replaces the initial null result.
+    if (candidateCost >= lowestCost)
+      continue;
+    best = &candidate;
+    lowestCost = candidateCost;
+  }
+  return best;
+}
+
-- 
GitLab