From 14bcdc4e9a4eb4c161b9013e649eeaff00ef047e Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 2 Mar 2020 22:32:50 +0100
Subject: [PATCH] Added actions for feats prediction

---
 reading_machine/include/Action.hpp     |  2 +
 reading_machine/include/Transition.hpp |  2 +
 reading_machine/src/Action.cpp         | 73 ++++++++++++++++++++++++++
 reading_machine/src/Transition.cpp     | 39 ++++++++++++++
 4 files changed, 116 insertions(+)

diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp
index a20f68a..994fe9c 100644
--- a/reading_machine/include/Action.hpp
+++ b/reading_machine/include/Action.hpp
@@ -49,7 +49,9 @@ class Action
   static Action moveWordIndex(int movement);
   static Action moveCharacterIndex(int movement);
   static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis);
+  static Action addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition);
   static Action addHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & hypothesis);
+  static Action addToHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & addition);
   static Action pushWordIndexOnStack();
   static Action popStack();
   static Action emptyStack();
diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp
index c3a4589..c7309a6 100644
--- a/reading_machine/include/Transition.hpp
+++ b/reading_machine/include/Transition.hpp
@@ -17,11 +17,13 @@ class Transition
   private :
 
   void initWrite(std::string colName, std::string object, std::string index, std::string value);
+  void initAdd(std::string colName, std::string object, std::string index, std::string value);
   void initShift();
   void initLeft(std::string label);
   void initRight(std::string label);
   void initReduce();
   void initEOS();
+  void initNothing();
 
   public :
 
diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp
index 6cf1678..ca95473 100644
--- a/reading_machine/src/Action.cpp
+++ b/reading_machine/src/Action.cpp
@@ -87,6 +87,79 @@ Action Action::addHypothesis(const std::string & colName, std::size_t lineIndex,
   return {Type::Write, apply, undo, appliable}; 
 }
 
+Action Action::addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition)
+{
+  auto apply = [colName, lineIndex, addition](Config & config, Action &)
+  {
+    auto & current = config.getLastNotEmptyHyp(colName, lineIndex);
+    current = util::isEmpty(current) ? addition : '|' + addition;
+  };
+
+  auto undo = [colName, lineIndex](Config & config, Action &)
+  {
+    std::string newValue = config.getLastNotEmpty(colName, lineIndex);
+    while (!newValue.empty() and newValue.back() == '|')
+      newValue.pop_back();
+    if (!newValue.empty())
+      newValue.pop_back();
+    config.getLastNotEmpty(colName, lineIndex) = newValue;
+  };
+
+  auto appliable = [colName, lineIndex, addition](const Config & config, const Action &)
+  {
+    if (!config.has(colName, lineIndex, 0))
+      return false;
+    auto & current = config.getLastNotEmptyHypConst(colName, lineIndex);
+    auto splited = util::split(current.get(), '|');
+    for (auto & part : splited)
+      if (part == addition)
+        return false;
+    return true;
+  };
+
+  return {Type::Write, apply, undo, appliable}; 
+}
+
+Action Action::addToHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & addition)
+{
+  auto apply = [colName, object, relativeIndex, addition](Config & config, Action & a)
+  {
+    int lineIndex = 0;
+    if (object == Object::Buffer)
+      lineIndex = config.getWordIndex() + relativeIndex;
+    else
+      lineIndex = config.getStack(relativeIndex);
+
+    return addToHypothesis(colName, lineIndex, addition).apply(config, a);
+  };
+
+  auto undo = [colName, object, relativeIndex](Config & config, Action & a)
+  {
+    int lineIndex = 0;
+    if (object == Object::Buffer)
+      lineIndex = config.getWordIndex() + relativeIndex;
+    else
+      lineIndex = config.getStack(relativeIndex);
+
+    return addToHypothesis(colName, lineIndex, "").undo(config, a);
+  };
+
+  auto appliable = [colName, object, relativeIndex, addition](const Config & config, const Action & a)
+  {
+    int lineIndex = 0;
+    if (object == Object::Buffer)
+      lineIndex = config.getWordIndex() + relativeIndex;
+    else if (config.hasStack(relativeIndex))
+      lineIndex = config.getStack(relativeIndex);
+    else
+      return false;
+
+    return addToHypothesis(colName, lineIndex, addition).appliable(config, a);
+  };
+
+  return {Type::Write, apply, undo, appliable}; 
+}
+
 Action Action::addHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & hypothesis)
 {
   auto apply = [colName, object, relativeIndex, hypothesis](Config & config, Action & a)
diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp
index f18450c..259e7f4 100644
--- a/reading_machine/src/Transition.cpp
+++ b/reading_machine/src/Transition.cpp
@@ -5,11 +5,13 @@ Transition::Transition(const std::string & name)
 {
   std::regex nameRegex("(<(.+)> )?(.+)");
   std::regex writeRegex("WRITE ([bs])\\.(.+) (.+) (.+)");
+  std::regex addRegex("ADD ([bs])\\.(.+) (.+) (.+)");
   std::regex shiftRegex("SHIFT");
   std::regex reduceRegex("REDUCE");
   std::regex leftRegex("LEFT (.+)");
   std::regex rightRegex("RIGHT (.+)");
   std::regex eosRegex("EOS");
+  std::regex nothingRegex("NOTHING");
 
   try
   {
@@ -22,6 +24,8 @@ Transition::Transition(const std::string & name)
 
   if (util::doIfNameMatch(writeRegex, this->name, [this](auto sm){initWrite(sm[3], sm[1], sm[2], sm[4]);}))
     return;
+  if (util::doIfNameMatch(addRegex, this->name, [this](auto sm){initAdd(sm[3], sm[1], sm[2], sm[4]);}))
+    return;
   if (util::doIfNameMatch(shiftRegex, this->name, [this](auto){initShift();}))
     return;
   if (util::doIfNameMatch(reduceRegex, this->name, [this](auto){initReduce();}))
@@ -32,6 +36,8 @@ Transition::Transition(const std::string & name)
     return;
   if (util::doIfNameMatch(eosRegex, this->name, [this](auto){initEOS();}))
     return;
+  if (util::doIfNameMatch(nothingRegex, this->name, [this](auto){initNothing();}))
+    return;
 
   throw std::invalid_argument("no match");
 
@@ -89,6 +95,39 @@ void Transition::initWrite(std::string colName, std::string object, std::string
   };
 }
 
+void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
+{
+  auto objectValue = Action::str2object(object);
+  int indexValue = std::stoi(index);
+
+  sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));
+
+  cost = [colName, objectValue, indexValue, value](const Config & config)
+  {
+    int lineIndex = 0;
+    if (objectValue == Action::Object::Buffer)
+      lineIndex = config.getWordIndex() + indexValue;
+    else
+      lineIndex = config.getStack(indexValue);
+
+    auto gold = util::split(config.getConst(colName, lineIndex, 0).get(), '|');
+
+    for (auto & part : gold)
+      if (part == value)
+        return 0;
+
+    return 1;
+  };
+}
+
+void Transition::initNothing()
+{
+  cost = [](const Config &)
+  {
+    return 0;
+  };
+}
+
 void Transition::initShift()
 {
   sequence.emplace_back(Action::pushWordIndexOnStack());
-- 
GitLab