From bb7307d4737006ad5e74572a3796d1ed560be0b0 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 22 Feb 2021 17:51:02 +0100
Subject: [PATCH] Fixed oracles of non rel arc eager

---
 reading_machine/src/Transition.cpp | 48 +++++++++++++++++++++++-------
 1 file changed, 38 insertions(+), 10 deletions(-)

diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp
index 34b3fdb..12e4e4a 100644
--- a/reading_machine/src/Transition.cpp
+++ b/reading_machine/src/Transition.cpp
@@ -442,11 +442,14 @@ void Transition::initEagerLeft_rel(std::string label)
   {
     auto depIndex = config.getStack(0);
     auto govIndex = config.getWordIndex();
+    auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
 
     int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config);
 
     if (label != config.getConst(Config::deprelColName, depIndex, 0))
       ++cost;
+    if (depGovIndex != std::to_string(govIndex))
+      ++cost;
 
     return cost;
   };
@@ -544,18 +547,28 @@ void Transition::initEagerLeft()
   costDynamic = [](const Config & config)
   {
     auto depIndex = config.getStack(0);
-    auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
     auto govIndex = config.getWordIndex();
-
-    if (depGovIndex == std::to_string(govIndex))
-      return 0;
+    auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
 
     int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config);
 
+    if (depGovIndex != std::to_string(govIndex))
+      cost += 1;
+
     return cost;
   };
 
-  costStatic = costDynamic;
+  costStatic = [](const Config & config)
+  {
+    auto depIndex = config.getStack(0);
+    auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
+    auto govIndex = config.getWordIndex();
+
+    if (depGovIndex == std::to_string(govIndex))
+      return 0;
+
+    return 1;
+  };
 }
 
 void Transition::initEagerRight_rel(std::string label)
@@ -566,13 +579,17 @@ void Transition::initEagerRight_rel(std::string label)
 
   costDynamic = [label](const Config & config)
   {
+    auto govIndex = config.getStack(0);
     auto depIndex = config.getWordIndex();
+    auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
 
     int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config);
     cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config);
 
     if (label != config.getConst(Config::deprelColName, depIndex, 0))
       ++cost;
+    if (depGovIndex == std::to_string(govIndex))
+      ++cost;
 
     return cost;
   };
@@ -669,20 +686,31 @@ void Transition::initEagerRight()
 
   costDynamic = [](const Config & config)
   {
-    auto govIndex = config.getStack(0);
     auto depIndex = config.getWordIndex();
+    auto govIndex = config.getStack(0);
     auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
 
-    if (depGovIndex == std::to_string(govIndex))
-      return 0;
-
     int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config);
     cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config);
 
+    if (depGovIndex != std::to_string(govIndex))
+      cost += 1;
+
     return cost;
   };
 
-  costStatic = costDynamic;
+  costStatic = [](const Config & config)
+  {
+    auto govIndex = config.getStack(0);
+    auto depIndex = config.getWordIndex();
+    auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
+
+    if (depGovIndex == std::to_string(govIndex))
+      return 0;
+
+    return 1;
+  };
+
 }
 
 void Transition::initReduce_strict()
-- 
GitLab