From c575429890046ca4cbfa110cda7c9bd80d5cb95d Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 12 Oct 2021 17:32:11 +0200
Subject: [PATCH] Il oracle costs, count arc root

---
 Transition.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/Transition.py b/Transition.py
index 3da12f7..967189e 100644
--- a/Transition.py
+++ b/Transition.py
@@ -175,7 +175,7 @@ def scoreOracleRight(config, ml, size, label) :
 def scoreOracleLeft(config, ml, size, label) :
   correct = 1 if config.getGold(config.stack[-size], "HEAD") == config.wordIndex else 0
   labelErr = 0 if label is None else (0 if config.getGold(config.stack[-size], "DEPREL") == label else 1)
-  return sum([ml["StackRight"+str(n)] for n in range(1,size+1)]) - correct + labelErr
+  return sum([ml["StackRight"+str(n)] for n in range(1,size+1)]) - correct + labelErr + (1 if config.getGold(config.stack[-size], "HEAD") == 0 else 0)
 ################################################################################
 
 ################################################################################
@@ -185,7 +185,7 @@ def scoreOracleShift(config, ml) :
 
 ################################################################################
 def scoreOracleReduce(config, ml) :
-  return ml["StackRight1"]
+  return ml["StackRight1"] + (1 if config.getGold(config.stack[0], "HEAD") == 0 else 0)
 ################################################################################
 
 ################################################################################
-- 
GitLab