From 00844dd5f6e3342625a64ae5aa7bd00366685f57 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 12 Jan 2021 16:54:29 +0100 Subject: [PATCH] Fixed R2 computation in eval script --- scripts/conll18_ud_eval.py | 44 ++++++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/scripts/conll18_ud_eval.py b/scripts/conll18_ud_eval.py index 3aab97b..123f252 100755 --- a/scripts/conll18_ud_eval.py +++ b/scripts/conll18_ud_eval.py @@ -365,20 +365,21 @@ def load_conllu(file) : def evaluate(gold_ud, system_ud, extraColumns) : class Score : def __init__(self, gold_total, system_total, correct, aligned_total=None, isNumeric=False, R2=None) : - self.correct = correct + self.correct = correct[0] self.gold_total = gold_total self.system_total = system_total self.aligned_total = aligned_total if isNumeric : self.precision = 0 self.recall = R2 - self.f1 = 2 * correct / (system_total + gold_total) if system_total + gold_total else 0.0 - self.aligned_accuracy = correct / aligned_total if aligned_total else aligned_total + self.f1 = correct[1] / gold_total if gold_total else 0.0 + self.aligned_accuracy = correct[0] / gold_total if gold_total else 0.0 + else : - self.precision = 100*correct / system_total if system_total else 0.0 - self.recall = 100*correct / gold_total if gold_total else 0.0 - self.f1 = 2 * 100*correct / (system_total + gold_total) if system_total + gold_total else 0.0 - self.aligned_accuracy = 100*correct / aligned_total if aligned_total else aligned_total + self.precision = 100*correct[0] / system_total if system_total else 0.0 + self.recall = 100*correct[0] / gold_total if gold_total else 0.0 + self.f1 = 2 * 100*correct[0] / (system_total + gold_total) if system_total + gold_total else 0.0 + self.aligned_accuracy = 100*correct[0] / aligned_total if aligned_total else aligned_total class AlignmentWord : def __init__(self, gold_word, system_word) : @@ -406,7 +407,7 @@ def evaluate(gold_ud, system_ud, 
extraColumns) : si += 1 gi += 1 - return [Score(len(gold_spans), len(system_spans), correct)] + return [Score(len(gold_spans), len(system_spans), [correct])] def alignment_score(alignment, key_fn=None, filter_fn=None) : if filter_fn is not None : @@ -420,7 +421,7 @@ def evaluate(gold_ud, system_ud, extraColumns) : if key_fn is None : # Return score for whole aligned words - return [Score(gold, system, aligned)] + return [Score(gold, system, [aligned])] def gold_aligned_gold(word) : return word @@ -434,7 +435,7 @@ def evaluate(gold_ud, system_ud, extraColumns) : if (not isinstance(systemItem, str) or '.' not in systemItem or not is_float(systemItem)) or (not isinstance(goldItem, str) or '.' not in goldItem or not is_float(goldItem)) : isNumericOnly = False - correct = 0 + correct = [0,0] errors = [] goldValues = [] predictedValues = [] @@ -444,23 +445,30 @@ def evaluate(gold_ud, system_ud, extraColumns) : systemItem = key_fn(words.system_word, gold_aligned_system) if not isNumericOnly : if goldItem == systemItem : - correct += 1 + correct[0] += 1 else : errors.append(words) - else : - correct -= abs(float(goldItem) - float(systemItem))**2 + # WARNING: this script ignores examples where gold value == 0.0 + elif float(goldItem) != 0.0 : + correct[0] -= abs(float(goldItem) - float(systemItem))**1 + correct[1] -= abs(float(goldItem) - float(systemItem))**2 goldValues.append(float(goldItem)) predictedValues.append(float(systemItem)) R2 = None if isNumericOnly : goldMean = sum(goldValues) / len(goldValues) - E1 = 0.0 - E2 = 0.0 + predMean = sum(predictedValues) / len(predictedValues) + numerator = 0.0 + denom1 = 0.0 + denom2 = 0.0 for i in range(len(predictedValues)) : - E1 += (goldValues[i]-predictedValues[i])**2 - E2 += (goldMean-predictedValues[i])**2 - R2 = 1 - E1/E2 + numerator += (predictedValues[i]-predMean)*(goldValues[i]-goldMean) + denom1 += (predictedValues[i]-predMean)**2 + denom2 += (goldValues[i]-goldMean)**2 + + pearson = 
numerator/((denom1**0.5)*(denom2**0.5)) + R2 = pearson**2 return [Score(gold, system, correct, aligned, isNumeric=isNumericOnly, R2=R2), errors] def beyond_end(words, i, multiword_span_end) : -- GitLab