Skip to content
Snippets Groups Projects
Commit 82591bb2 authored by Franck Dary
Browse files

Updated eval script to print most common error types

parent 0d8b5c00
No related branches found
No related tags found
No related merge requests found
......@@ -94,9 +94,11 @@ from __future__ import print_function
import argparse
import io
import os
import sys
import unicodedata
import unittest
import math
# CoNLL-U column names
ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC = range(10)
......@@ -157,6 +159,8 @@ def load_conllu(file) :
self.sentences = []
# List of UDSpan instances with start&end indices into `words`.
self.sentences_words = []
# Name of the file this representation has been extracted from.
self.filename = ""
class UDSpan :
def __init__(self, start, end) :
self.start = start
......@@ -189,6 +193,7 @@ def load_conllu(file) :
self.is_functional_deprel = self.columns[DEPREL] in FUNCTIONAL_DEPRELS
ud = UDRepresentation()
ud.filename = file.name
# Load the CoNLL-U file
index, sentence_start = 0, None
......@@ -511,34 +516,60 @@ def load_conllu_file(path) :
def evaluate_wrapper(args):
    """Load the gold file and every system file, and evaluate each system.

    Args:
        args: parsed command-line arguments. `args.gold_file` and
            `args.system_file` are paths handed to `load_conllu_file`;
            `args.system_file2` is an optional second system path (may be
            None).

    Returns:
        A pair `(gold_ud, evaluations)` where `evaluations` is a list of
        `(system_ud, evaluate(gold_ud, system_ud))` tuples, one per system
        file, in the order the files were given on the command line.
    """
    # Load CoNLL-U files
    gold_ud = load_conllu_file(args.gold_file)
    system_files = [load_conllu_file(args.system_file)]
    # The second system file is optional; it is evaluated against the same gold.
    if args.system_file2 is not None:
        system_files.append(load_conllu_file(args.system_file2))
    return gold_ud, [(system, evaluate(gold_ud, system)) for system in system_files]
################################################################################
################################################################################
def compute_errors(gold_file, system_file, evaluation, metric):
    """Group the aligned word pairs of one metric by error type.

    Args:
        gold_file: gold representation; its `words` and `sentences_words`
            attributes are read to recover the sentence of each gold word.
        system_file: predicted representation, read the same way.
        evaluation: mapping indexed by metric name. `evaluation[metric][1]`
            is iterated, and each item must expose `gold_word` and
            `system_word`.
        metric: name of the metric to inspect; stored on the result.

    Returns:
        An `Errors` collection. `len()` of it is the number of recorded pairs,
        `types` lists one `ErrorType` per distinct "GOLD->PRED" label, and
        `sort()` orders `types` from most to least frequent.
    """
    class Error:
        """One gold/predicted word pair together with both full sentences."""

        def __init__(self, gold_file, system_file, gold_word, system_word, metric):
            self.gold = gold_word
            self.pred = system_word
            gold_span = gold_file.sentences_words[gold_word.sentence]
            pred_span = system_file.sentences_words[system_word.sentence]
            self.gold_sentence = gold_file.words[gold_span.start:gold_span.end]
            self.pred_sentence = system_file.words[pred_span.start:pred_span.end]
            # TODO : do it for other than UPOS
            # Built from this constructor's own arguments rather than from
            # variables of the enclosing function.
            self.type = gold_word.columns[UPOS] + "->" + system_word.columns[UPOS]

    class ErrorType:
        """All errors sharing the same "GOLD->PRED" label."""

        def __init__(self, error_type):
            self.type = error_type
            self.errors = []

        def __len__(self):
            return len(self.errors)

        def add(self, error):
            self.errors.append(error)

    class Errors:
        """Collection of errors for one metric, bucketed by error type."""

        def __init__(self, metric):
            self.types = []
            self.nb_errors = 0
            self.metric = metric
            # Label -> ErrorType index, so add() does not rescan `types`.
            self._by_type = {}

        def __len__(self):
            return self.nb_errors

        def add(self, error):
            self.nb_errors += 1
            bucket = self._by_type.get(error.type)
            # Compare against None explicitly: ErrorType defines __len__, so
            # an empty bucket would be falsy.
            if bucket is None:
                bucket = ErrorType(error.type)
                self._by_type[error.type] = bucket
                self.types.append(bucket)
            bucket.add(error)

        def sort(self):
            # Most frequent error types first.
            self.types.sort(key=len, reverse=True)

    errors = Errors(metric)
    for alignment_word in evaluation[metric][1]:
        gold = alignment_word.gold_word
        pred = alignment_word.system_word
        errors.add(Error(gold_file, system_file, gold, pred, metric))

    return errors
################################################################################
......@@ -556,13 +587,20 @@ def main() :
help="Print raw counts of correct/gold/system/aligned words instead of prec/rec/F1 for all metrics.")
parser.add_argument("--system_file2",
help="Name of another CoNLL-U file with predicted data, for error comparison.")
parser.add_argument("--enumerate_errors", "-e", default=None,
help="Comma separated list of column names for which to enumerate errors (e.g. \"UPOS,FEATS\").")
args = parser.parse_args()
errors_metrics = [] if args.enumerate_errors is None else args.enumerate_errors.split(',')
# Evaluate
evaluation, files = evaluate_wrapper(args)
gold_ud, evaluations = evaluate_wrapper(args)
for (system_ud, evaluation) in evaluations :
fnamelen = len(system_ud.filename)
print("*"*math.ceil((80-2-fnamelen)/2),system_ud.filename,"*"*math.floor((80-2-fnamelen)/2))
# Compute errors
errors = compute_errors(files[0], files[1], evaluation, "UPOS")
errors_list = [compute_errors(gold_ud, system_ud, evaluation, metric) for metric in errors_metrics]
# Print the evaluation
if args.counts :
......@@ -587,6 +625,16 @@ def main() :
100 * evaluation[metric][0].f1,
"{:10.2f}".format(100 * evaluation[metric][0].aligned_accuracy) if evaluation[metric][0].aligned_accuracy is not None else ""
))
for errors in errors_list :
errors.sort()
print("")
print("Most frequent errors for metric '{}' :".format(errors.metric))
for error_type in errors.types[:10] :
t = error_type.type
nb = len(error_type)
percent = 100.0*nb/len(errors)
print("{:5} {:5.2f}% {}".format(nb, percent, t))
################################################################################
......@@ -595,56 +643,3 @@ if __name__ == "__main__" :
main()
################################################################################
################################################################################
# Tests, which can be executed with `python -m unittest conll18_ud_eval`.
class TestAlignment(unittest.TestCase):
    """Checks word alignment between gold and system data on tiny inputs."""

    @staticmethod
    def _load_words(words):
        """Build an in-memory CoNLL-U file from word specs and load it.

        A spec with no space is a plain word. A spec with spaces is a
        multi-word token: the surface form followed by its component words.
        A fake HEAD is emitted (0 for the first word, 1 otherwise) to prevent
        multiple roots errors.
        """
        conllu_lines = []
        word_count = 0
        for spec in words:
            pieces = spec.split(" ")
            if len(pieces) == 1:
                word_count += 1
                conllu_lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(word_count, pieces[0], int(word_count>1)))
                continue
            # Multi-word token: one range line, then one line per component.
            conllu_lines.append("{}-{}\t{}\t_\t_\t_\t_\t_\t_\t_\t_".format(word_count + 1, word_count + len(pieces) - 1, pieces[0]))
            for piece in pieces[1:]:
                word_count += 1
                conllu_lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(word_count, piece, int(word_count>1)))
        # Text buffer on Python 3, byte buffer on Python 2.
        buffer_type = io.StringIO if sys.version_info >= (3, 0) else io.BytesIO
        return load_conllu(buffer_type("\n".join(conllu_lines+["\n"])))

    def _test_exception(self, gold, system):
        """Assert that evaluating `system` against `gold` raises UDError."""
        # Both files are loaded outside the assertion so that only a UDError
        # raised by evaluate() itself satisfies it.
        gold_ud = self._load_words(gold)
        system_ud = self._load_words(system)
        with self.assertRaises(UDError):
            evaluate(gold_ud, system_ud)

    def _test_ok(self, gold, system, correct):
        """Assert the "Words" precision/recall/F1 given `correct` matches."""
        def count_words(specs):
            # A multi-word token contributes its components, not its surface form.
            return sum(max(1, len(spec.split(" ")) - 1) for spec in specs)

        metrics = evaluate(self._load_words(gold), self._load_words(system))
        gold_words = count_words(gold)
        system_words = count_words(system)
        words_score = metrics["Words"]
        observed = (words_score.precision, words_score.recall, words_score.f1)
        expected = (correct / system_words, correct / gold_words, 2 * correct / (gold_words + system_words))
        self.assertEqual(observed, expected)

    def _run_cases(self, cases):
        """Run `_test_ok` on each (gold, system, correct) triple, in order."""
        for gold, system, correct in cases:
            self._test_ok(gold, system, correct)

    def test_exception(self):
        self._test_exception(["a"], ["b"])

    def test_equal(self):
        self._run_cases((
            (["a"], ["a"], 1),
            (["a", "b", "c"], ["a", "b", "c"], 3),
        ))

    def test_equal_with_multiword(self):
        self._run_cases((
            (["abc a b c"], ["a", "b", "c"], 3),
            (["a", "bc b c", "d"], ["a", "b", "c", "d"], 4),
            (["abcd a b c d"], ["ab a b", "cd c d"], 4),
            (["abc a b c", "de d e"], ["a", "bcd b c d", "e"], 5),
        ))

    def test_alignment(self):
        self._run_cases((
            (["abcd"], ["a", "b", "c", "d"], 0),
            (["abc", "d"], ["a", "b", "c", "d"], 1),
            (["a", "bc", "d"], ["a", "b", "c", "d"], 2),
            (["a", "bc b c", "d"], ["a", "b", "cd"], 2),
            (["abc a BX c", "def d EX f"], ["ab a b", "cd c d", "ef e f"], 4),
            (["ab a b", "cd bc d"], ["a", "bc", "d"], 2),
            (["a", "bc b c", "d"], ["ab AX BX", "cd CX a"], 1),
        ))
################################################################################
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment