diff --git a/scripts/conll18_ud_eval.py b/scripts/conll18_ud_eval.py index 2c8b0913e060c18e3dac80de43e9007e7f43f609..3aab97b97ef900dddaded192c8e273395bba9c87 100755 --- a/scripts/conll18_ud_eval.py +++ b/scripts/conll18_ud_eval.py @@ -364,15 +364,22 @@ def load_conllu(file) : # Evaluate the gold and system treebanks (loaded using load_conllu). def evaluate(gold_ud, system_ud, extraColumns) : class Score : - def __init__(self, gold_total, system_total, correct, aligned_total=None) : + def __init__(self, gold_total, system_total, correct, aligned_total=None, isNumeric=False, R2=None) : self.correct = correct self.gold_total = gold_total self.system_total = system_total self.aligned_total = aligned_total - self.precision = correct / system_total if system_total else 0.0 - self.recall = correct / gold_total if gold_total else 0.0 - self.f1 = 2 * correct / (system_total + gold_total) if system_total + gold_total else 0.0 - self.aligned_accuracy = correct / aligned_total if aligned_total else aligned_total + if isNumeric : + self.precision = 0 + self.recall = R2 + self.f1 = 2 * correct / (system_total + gold_total) if system_total + gold_total else 0.0 + self.aligned_accuracy = correct / aligned_total if aligned_total else aligned_total + else : + self.precision = 100*correct / system_total if system_total else 0.0 + self.recall = 100*correct / gold_total if gold_total else 0.0 + self.f1 = 2 * 100*correct / (system_total + gold_total) if system_total + gold_total else 0.0 + self.aligned_accuracy = 100*correct / aligned_total if aligned_total else aligned_total + class AlignmentWord : def __init__(self, gold_word, system_word) : self.gold_word = gold_word @@ -419,21 +426,42 @@ def evaluate(gold_ud, system_ud, extraColumns) : return word def gold_aligned_system(word) : return alignment.matched_words_map.get(word, "NotAligned") if word is not None else None + isNumericOnly = True + for words in alignment.matched_words : + if filter_fn is None or filter_fn(words.gold_word) : + goldItem = key_fn(words.gold_word, gold_aligned_gold) + systemItem = key_fn(words.system_word, gold_aligned_system) + if (not isinstance(systemItem, str) or '.' not in systemItem or not is_float(systemItem)) or (not isinstance(goldItem, str) or '.' not in goldItem or not is_float(goldItem)) : + isNumericOnly = False + correct = 0 errors = [] + goldValues = [] + predictedValues = [] for words in alignment.matched_words : if filter_fn is None or filter_fn(words.gold_word) : goldItem = key_fn(words.gold_word, gold_aligned_gold) systemItem = key_fn(words.system_word, gold_aligned_system) - if (not isinstance(systemItem, str) or '.' not in systemItem or not is_float(systemItem)) or (not isinstance(goldItem, str) or '.' not in goldItem or not is_float(goldItem)) : + if not isNumericOnly : if goldItem == systemItem : correct += 1 else : errors.append(words) else : - correct -= abs(float(goldItem) - float(systemItem)) - - return [Score(gold, system, correct, aligned), errors] + correct -= abs(float(goldItem) - float(systemItem))**2 + goldValues.append(float(goldItem)) + predictedValues.append(float(systemItem)) + + R2 = None + if isNumericOnly : + goldMean = sum(goldValues) / len(goldValues) + E1 = 0.0 + E2 = 0.0 + for i in range(len(predictedValues)) : + E1 += (goldValues[i]-predictedValues[i])**2 + E2 += (goldMean-predictedValues[i])**2 + R2 = 1 - E1/E2 + return [Score(gold, system, correct, aligned, isNumeric=isNumericOnly, R2=R2), errors] def beyond_end(words, i, multiword_span_end) : if i >= len(words) : @@ -725,10 +753,10 @@ def main() : print("{:{}}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format( metric, maxColNameSize, - 100 * evaluation[metric][0].precision, - 100 * evaluation[metric][0].recall, - 100 * evaluation[metric][0].f1, - "{:10.2f}".format(100 * evaluation[metric][0].aligned_accuracy) if evaluation[metric][0].aligned_accuracy is not None else "" + evaluation[metric][0].precision, + evaluation[metric][0].recall, + evaluation[metric][0].f1, + "{:10.2f}".format(evaluation[metric][0].aligned_accuracy) if evaluation[metric][0].aligned_accuracy is not None else "" )) for id2 in range(len(errors_list)) :