From 3c41224b300cc98b35192df3ec8c564fa15b2852 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 7 Jul 2020 16:05:55 +0200
Subject: [PATCH] Explore different transitions only for parser

---
 reading_machine/include/Transition.hpp |  2 ++
 reading_machine/src/Transition.cpp     | 37 ++++++++++++++++++++++++++
 trainer/src/Trainer.cpp                |  6 ++++-
 3 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp
index a55f1b7..5854979 100644
--- a/reading_machine/include/Transition.hpp
+++ b/reading_machine/include/Transition.hpp
@@ -43,6 +43,7 @@ class Transition
   void initGoldReduce_strict();
   void initReduce_relaxed();
   void initEOS(int bufferIndex);
+  void initNotEOS(int bufferIndex);
   void initNothing();
   void initIgnoreChar();
   void initEndWord();
@@ -52,6 +53,7 @@ class Transition
   void initTransformSuffix(std::string fromCol, std::string fromObj, std::string fromIndex, std::string toCol, std::string toObj, std::string toIndex, std::string rule);
   void initUppercase(std::string col, std::string obj, std::string index);
   void initUppercaseIndex(std::string col, std::string obj, std::string index, std::string inIndex);
+  void initNothing(std::string col, std::string obj, std::string index);
   void initLowercase(std::string col, std::string obj, std::string index);
   void initLowercaseIndex(std::string col, std::string obj, std::string index, std::string inIndex);
 
diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp
index 27ca57b..c8d3804 100644
--- a/reading_machine/src/Transition.cpp
+++ b/reading_machine/src/Transition.cpp
@@ -43,6 +43,8 @@ Transition::Transition(const std::string & name)
       [this](auto sm){initEOS(std::stoi(sm[1]));}},
     {std::regex("NOTHING"),
       [this](auto){initNothing();}},
+    {std::regex("NOTEOS b\\.(.+)"),
+      [this](auto sm){initNotEOS(std::stoi(sm[1]));}},
     {std::regex("IGNORECHAR"),
       [this](auto){initIgnoreChar();}},
     {std::regex("ENDWORD"),
@@ -57,6 +59,8 @@ Transition::Transition(const std::string & name)
       [this](auto sm){(initUppercase(sm[1], sm[2], sm[3]));}},
     {std::regex("UPPERCASEINDEX (.+) ([bs])\\.(.+) (.+)"),
       [this](auto sm){(initUppercaseIndex(sm[1], sm[2], sm[3], sm[4]));}},
+    {std::regex("NOTHING (.+) ([bs])\\.(.+)"),
+      [this](auto sm){(initNothing(sm[1], sm[2], sm[3]));}},
     {std::regex("LOWERCASE (.+) ([bs])\\.(.+)"),
       [this](auto sm){(initLowercase(sm[1], sm[2], sm[3]));}},
     {std::regex("LOWERCASEINDEX (.+) ([bs])\\.(.+) (.+)"),
@@ -713,6 +717,20 @@ void Transition::initEOS(int bufferIndex)
   costStatic = costDynamic;
 }
 
+void Transition::initNotEOS(int bufferIndex)
+{
+  costDynamic = [bufferIndex](const Config & config)
+  {
+    int lineIndex = config.getRelativeWordIndex(Config::Buffer, bufferIndex);
+    if (config.getConst(Config::EOSColName, lineIndex, 0) == Config::EOSSymbol1)
+      return std::numeric_limits<int>::max();
+
+    return 0;
+  };
+
+  costStatic = costDynamic;
+}
+
 void Transition::initDeprel(std::string label)
 {
   sequence.emplace_back(Action::deprel(label));
@@ -815,6 +833,25 @@ void Transition::initUppercaseIndex(std::string col, std::string obj, std::strin
   costStatic = costDynamic;
 }
 
+void Transition::initNothing(std::string col, std::string obj, std::string index)
+{
+  auto objectValue = Config::str2object(obj);
+  int indexValue = std::stoi(index);
+
+  costDynamic = [col, objectValue, indexValue](const Config & config)
+  {
+    int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
+    auto & expectedValue = config.getConst(col, lineIndex, 0);
+    std::string currentValue = config.getAsFeature(col, lineIndex).get();
+    if (expectedValue == currentValue)
+      return 0;
+
+    return 1;
+  };
+
+  costStatic = costDynamic;
+}
+
 void Transition::initLowercase(std::string col, std::string obj, std::string index)
 {
   auto objectValue = Config::str2object(obj);
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 32932eb..e732e79 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -128,7 +128,11 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
     Transition * transition = nullptr;
 
     auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle);
-    Transition * goldTransition = goldTransitions[std::rand()%goldTransitions.size()];
+
+    Transition * goldTransition = goldTransitions[0];
+    if (config.getState() == "parser")
+      goldTransitions[std::rand()%goldTransitions.size()];
+
     int nbClasses = machine.getTransitionSet(config.getState()).size();
       
     if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
-- 
GitLab