diff --git a/Transition.py b/Transition.py index 3da12f780e6e14beb269fb99f48bd969231476a0..967189e77a9e0bf2458e6877f10cce7d8442cc61 100644 --- a/Transition.py +++ b/Transition.py @@ -175,7 +175,7 @@ def scoreOracleRight(config, ml, size, label) : def scoreOracleLeft(config, ml, size, label) : correct = 1 if config.getGold(config.stack[-size], "HEAD") == config.wordIndex else 0 labelErr = 0 if label is None else (0 if config.getGold(config.stack[-size], "DEPREL") == label else 1) - return sum([ml["StackRight"+str(n)] for n in range(1,size+1)]) - correct + labelErr + return sum([ml["StackRight"+str(n)] for n in range(1,size+1)]) - correct + labelErr + (1 if config.getGold(config.stack[-size], "HEAD") == 0 else 0) ################################################################################ ################################################################################ @@ -185,7 +185,7 @@ def scoreOracleShift(config, ml) : ################################################################################ def scoreOracleReduce(config, ml) : - return ml["StackRight1"] + return ml["StackRight1"] + (1 if config.getGold(config.stack[0], "HEAD") == 0 else 0) ################################################################################ ################################################################################