From c575429890046ca4cbfa110cda7c9bd80d5cb95d Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 12 Oct 2021 17:32:11 +0200 Subject: [PATCH] Il oracle costs, count arc root --- Transition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Transition.py b/Transition.py index 3da12f7..967189e 100644 --- a/Transition.py +++ b/Transition.py @@ -175,7 +175,7 @@ def scoreOracleRight(config, ml, size, label) : def scoreOracleLeft(config, ml, size, label) : correct = 1 if config.getGold(config.stack[-size], "HEAD") == config.wordIndex else 0 labelErr = 0 if label is None else (0 if config.getGold(config.stack[-size], "DEPREL") == label else 1) - return sum([ml["StackRight"+str(n)] for n in range(1,size+1)]) - correct + labelErr + return sum([ml["StackRight"+str(n)] for n in range(1,size+1)]) - correct + labelErr + (1 if config.getGold(config.stack[-size], "HEAD") == 0 else 0) ################################################################################ ################################################################################ @@ -185,7 +185,7 @@ def scoreOracleShift(config, ml) : ################################################################################ def scoreOracleReduce(config, ml) : - return ml["StackRight1"] + return ml["StackRight1"] + (1 if config.getGold(config.stack[0], "HEAD") == 0 else 0) ################################################################################ ################################################################################ -- GitLab