Commit 82591bb2 authored by Franck Dary

Updated eval script to print most common error types

parent 0d8b5c00
@@ -94,9 +94,11 @@ from __future__ import print_function
 import argparse
 import io
+import os
 import sys
 import unicodedata
 import unittest
+import math

 # CoNLL-U column names
 ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC = range(10)
@@ -157,6 +159,8 @@ def load_conllu(file) :
             self.sentences = []
             # List of UDSpan instances with start&end indices into `words`.
             self.sentences_words = []
+            # Name of the file this representation has been extracted from.
+            self.filename = ""
     class UDSpan :
         def __init__(self, start, end) :
             self.start = start
@@ -189,6 +193,7 @@ def load_conllu(file) :
             self.is_functional_deprel = self.columns[DEPREL] in FUNCTIONAL_DEPRELS

     ud = UDRepresentation()
+    ud.filename = file.name

     # Load the CoNLL-U file
     index, sentence_start = 0, None
@@ -511,34 +516,60 @@ def load_conllu_file(path) :
 def evaluate_wrapper(args) :
     # Load CoNLL-U files
     gold_ud = load_conllu_file(args.gold_file)
-    system_ud = load_conllu_file(args.system_file)
+    system_files = [load_conllu_file(args.system_file)]
     if args.system_file2 is not None :
-        print("TODO")
-        #TODO
+        system_files.append(load_conllu_file(args.system_file2))

-    return evaluate(gold_ud, system_ud), [gold_ud, system_ud]
+    return gold_ud, [(system, evaluate(gold_ud, system)) for system in system_files]
 ################################################################################

 ################################################################################
 def compute_errors(gold_file, system_file, evaluation, metric) :
-    errors = {}
+    class Error :
+        def __init__(self, gold_file, system_file, gold_word, system_word, metric) :
+            self.gold = gold_word
+            self.pred = system_word
+            self.gold_sentence = gold_file.words[gold_file.sentences_words[self.gold.sentence].start:gold_file.sentences_words[self.gold.sentence].end]
+            self.pred_sentence = system_file.words[system_file.sentences_words[self.pred.sentence].start:system_file.sentences_words[self.pred.sentence].end]
+            # TODO : do it for other than UPOS
+            self.type = gold.columns[UPOS]+"->"+pred.columns[UPOS]
+
+    class Errors :
+        def __init__(self, metric) :
+            self.types = []
+            self.nb_errors = 0
+            self.metric = metric
+        def __len__(self) :
+            return self.nb_errors
+        def add(self, error) :
+            self.nb_errors += 1
+            for t in self.types :
+                if t.type == error.type :
+                    t.add(error)
+                    return
+            self.types.append(ErrorType(error.type))
+            self.types[-1].add(error)
+        def sort(self) :
+            self.types.sort(key=len, reverse=True)
+
+    class ErrorType :
+        def __init__(self, error_type) :
+            self.type = error_type
+            self.errors = []
+        def __len__(self) :
+            return len(self.errors)
+        def add(self, error) :
+            self.errors.append(error)
+
+    errors = Errors(metric)
     for alignment_word in evaluation[metric][1] :
         gold = alignment_word.gold_word
         pred = alignment_word.system_word
-        error_type = gold.columns[UPOS]+"->"+pred.columns[UPOS]
-        gold_sentence_start = gold_file.sentences_words[gold.sentence].start
-        gold_sentence_end = gold_file.sentences_words[gold.sentence].end
-        pred_sentence_start = system_file.sentences_words[pred.sentence].start
-        pred_sentence_end = system_file.sentences_words[pred.sentence].end
-        error = [gold, pred, gold_file.words[gold_sentence_start:gold_sentence_end], system_file.words[pred_sentence_start:pred_sentence_end]]
-        if error_type not in errors :
-            errors[error_type] = []
-        errors[error_type].append(error)
+        error = Error(gold_file, system_file, gold, pred, metric)
+        errors.add(error)

     return errors
 ################################################################################
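The Error/Errors/ErrorType classes added in this hunk boil down to counting "gold->predicted" label strings and reporting the most frequent ones. A minimal sketch of the same idea, using collections.Counter instead of the commit's classes, with made-up UPOS pairs purely for illustration:

from collections import Counter

# Count confusions as "gold->predicted" strings, as compute_errors does for UPOS.
# The pairs below are placeholders, not real evaluation output.
confusions = Counter()
for gold_upos, pred_upos in [("NOUN", "VERB"), ("NOUN", "VERB"), ("ADJ", "NOUN")] :
    confusions[gold_upos + "->" + pred_upos] += 1

# Print the ten most frequent types with their share of all errors,
# mirroring the "{:5} {:5.2f}% {}" display used further down in the script.
total = sum(confusions.values())
for error_type, nb in confusions.most_common(10) :
    print("{:5} {:5.2f}% {}".format(nb, 100.0 * nb / total, error_type))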
@@ -556,37 +587,54 @@ def main() :
                         help="Print raw counts of correct/gold/system/aligned words instead of prec/rec/F1 for all metrics.")
     parser.add_argument("--system_file2",
                         help="Name of another CoNLL-U file with predicted data, for error comparison.")
+    parser.add_argument("--enumerate_errors", "-e", default=None,
+                        help="Comma separated list of column names for which to enumerate errors (e.g. \"UPOS,FEATS\").")
     args = parser.parse_args()
+    errors_metrics = [] if args.enumerate_errors is None else args.enumerate_errors.split(',')

     # Evaluate
-    evaluation, files = evaluate_wrapper(args)
-    # Compute errors
-    errors = compute_errors(files[0], files[1], evaluation, "UPOS")
-
-    # Print the evaluation
-    if args.counts :
-        print("Metric     | Correct   |      Gold | Predicted | Aligned")
-    else :
-        print("Metric     | Precision |    Recall |  F1 Score | AligndAcc")
-    print("-----------+-----------+-----------+-----------+-----------")
-    for metric in["Tokens", "Sentences", "Words", "UPOS", "XPOS", "UFeats", "AllTags", "Lemmas", "UAS", "LAS", "CLAS", "MLAS", "BLEX"] :
-        if args.counts :
-            print("{:11}|{:10} |{:10} |{:10} |{:10}".format(
-                metric,
-                evaluation[metric][0].correct,
-                evaluation[metric][0].gold_total,
-                evaluation[metric][0].system_total,
-                evaluation[metric][0].aligned_total or (evaluation[metric][0].correct if metric == "Words" else "")
-            ))
-        else :
-            print("{:11}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format(
-                metric,
-                100 * evaluation[metric][0].precision,
-                100 * evaluation[metric][0].recall,
-                100 * evaluation[metric][0].f1,
-                "{:10.2f}".format(100 * evaluation[metric][0].aligned_accuracy) if evaluation[metric][0].aligned_accuracy is not None else ""
-            ))
+    gold_ud, evaluations = evaluate_wrapper(args)
+
+    for (system_ud, evaluation) in evaluations :
+        fnamelen = len(system_ud.filename)
+        print("*"*math.ceil((80-2-fnamelen)/2),system_ud.filename,"*"*math.floor((80-2-fnamelen)/2))
+        # Compute errors
+        errors_list = [compute_errors(gold_ud, system_ud, evaluation, metric) for metric in errors_metrics]
+
+        # Print the evaluation
+        if args.counts :
+            print("Metric     | Correct   |      Gold | Predicted | Aligned")
+        else :
+            print("Metric     | Precision |    Recall |  F1 Score | AligndAcc")
+        print("-----------+-----------+-----------+-----------+-----------")
+        for metric in["Tokens", "Sentences", "Words", "UPOS", "XPOS", "UFeats", "AllTags", "Lemmas", "UAS", "LAS", "CLAS", "MLAS", "BLEX"] :
+            if args.counts :
+                print("{:11}|{:10} |{:10} |{:10} |{:10}".format(
+                    metric,
+                    evaluation[metric][0].correct,
+                    evaluation[metric][0].gold_total,
+                    evaluation[metric][0].system_total,
+                    evaluation[metric][0].aligned_total or (evaluation[metric][0].correct if metric == "Words" else "")
+                ))
+            else :
+                print("{:11}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format(
+                    metric,
+                    100 * evaluation[metric][0].precision,
+                    100 * evaluation[metric][0].recall,
+                    100 * evaluation[metric][0].f1,
+                    "{:10.2f}".format(100 * evaluation[metric][0].aligned_accuracy) if evaluation[metric][0].aligned_accuracy is not None else ""
+                ))
+
+        for errors in errors_list :
+            errors.sort()
+            print("")
+            print("Most frequent errors for metric '{}' :".format(errors.metric))
+            for error_type in errors.types[:10] :
+                t = error_type.type
+                nb = len(error_type)
+                percent = 100.0*nb/len(errors)
+                print("{:5} {:5.2f}% {}".format(nb, percent, t))
 ################################################################################
@@ -595,56 +643,3 @@ if __name__ == "__main__" :
     main()
 ################################################################################
-################################################################################
-
-# Tests, which can be executed with `python -m unittest conll18_ud_eval`.
-class TestAlignment(unittest.TestCase) :
-    @staticmethod
-    def _load_words(words) :
-        """Prepare fake CoNLL-U files with fake HEAD to prevent multiple roots errors."""
-        lines, num_words = [], 0
-        for w in words :
-            parts = w.split(" ")
-            if len(parts) == 1 :
-                num_words += 1
-                lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(num_words, parts[0], int(num_words>1)))
-            else :
-                lines.append("{}-{}\t{}\t_\t_\t_\t_\t_\t_\t_\t_".format(num_words + 1, num_words + len(parts) - 1, parts[0]))
-                for part in parts[1:] :
-                    num_words += 1
-                    lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(num_words, part, int(num_words>1)))
-        return load_conllu((io.StringIO if sys.version_info >= (3, 0) else io.BytesIO)("\n".join(lines+["\n"])))
-
-    def _test_exception(self, gold, system) :
-        self.assertRaises(UDError, evaluate, self._load_words(gold), self._load_words(system))
-
-    def _test_ok(self, gold, system, correct) :
-        metrics = evaluate(self._load_words(gold), self._load_words(system))
-        gold_words = sum((max(1, len(word.split(" ")) - 1) for word in gold))
-        system_words = sum((max(1, len(word.split(" ")) - 1) for word in system))
-        self.assertEqual((metrics["Words"].precision, metrics["Words"].recall, metrics["Words"].f1),
-                         (correct / system_words, correct / gold_words, 2 * correct / (gold_words + system_words)))
-
-    def test_exception(self) :
-        self._test_exception(["a"], ["b"])
-
-    def test_equal(self) :
-        self._test_ok(["a"], ["a"], 1)
-        self._test_ok(["a", "b", "c"], ["a", "b", "c"], 3)
-
-    def test_equal_with_multiword(self) :
-        self._test_ok(["abc a b c"], ["a", "b", "c"], 3)
-        self._test_ok(["a", "bc b c", "d"], ["a", "b", "c", "d"], 4)
-        self._test_ok(["abcd a b c d"], ["ab a b", "cd c d"], 4)
-        self._test_ok(["abc a b c", "de d e"], ["a", "bcd b c d", "e"], 5)
-
-    def test_alignment(self) :
-        self._test_ok(["abcd"], ["a", "b", "c", "d"], 0)
-        self._test_ok(["abc", "d"], ["a", "b", "c", "d"], 1)
-        self._test_ok(["a", "bc", "d"], ["a", "b", "c", "d"], 2)
-        self._test_ok(["a", "bc b c", "d"], ["a", "b", "cd"], 2)
-        self._test_ok(["abc a BX c", "def d EX f"], ["ab a b", "cd c d", "ef e f"], 4)
-        self._test_ok(["ab a b", "cd bc d"], ["a", "bc", "d"], 2)
-        self._test_ok(["a", "bc b c", "d"], ["ab AX BX", "cd CX a"], 1)
-################################################################################
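Taken together, the new pieces can also be driven from Python rather than through main(). A rough usage sketch, assuming the script stays importable as conll18_ud_eval (the module name referenced by the removed unittest comment) and using placeholder .conllu paths:

# Load gold and predicted CoNLL-U files, evaluate, then list the ten most
# frequent UPOS confusions, mirroring what the updated main() prints.
from conll18_ud_eval import load_conllu_file, evaluate, compute_errors

gold_ud = load_conllu_file("gold.conllu")        # placeholder path
system_ud = load_conllu_file("system.conllu")    # placeholder path
evaluation = evaluate(gold_ud, system_ud)

errors = compute_errors(gold_ud, system_ud, evaluation, "UPOS")
errors.sort()
for error_type in errors.types[:10] :
    print(len(error_type), error_type.type)

From the command line, the same summary is enabled by the new --enumerate_errors option (e.g. --enumerate_errors UPOS).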