Skip to content
Snippets Groups Projects
Commit 334d4661 authored by Franck Dary's avatar Franck Dary
Browse files

Error comparison for eval script working for upos

parent 97c328f2
No related branches found
No related tags found
No related merge requests found
......@@ -539,7 +539,6 @@ def evaluate_wrapper(args) :
################################################################################
def compute_errors(gold_file, system_file, evaluation, metric) :
class Error :
def __init__(self, gold_file, system_file, gold_word, system_word, metric) :
self.gold = gold_word
......@@ -547,7 +546,7 @@ def compute_errors(gold_file, system_file, evaluation, metric) :
self.gold_sentence = gold_file.words[gold_file.sentences_words[self.gold.sentence].start:gold_file.sentences_words[self.gold.sentence].end]
self.pred_sentence = system_file.words[system_file.sentences_words[self.pred.sentence].start:system_file.sentences_words[self.pred.sentence].end]
# TODO : do it for other than UPOS
self.type = gold.columns[UPOS]+"->"+pred.columns[UPOS]
self.type = self.gold.columns[UPOS]+"->"+self.pred.columns[UPOS]
def __str__(self) :
result = []
gold_lines = []
......@@ -562,10 +561,15 @@ def compute_errors(gold_file, system_file, evaluation, metric) :
return "\n".join(result)
class Errors :
def __init__(self, metric) :
def __init__(self, metric, errors1=None, errors2=None) :
self.types = []
self.nb_errors = 0
self.metric = metric
if errors1 is not None and errors2 is not None :
for type in errors1.types :
for error in type.errors :
if not errors2.has(error) :
self.add(error)
def __len__(self) :
return self.nb_errors
def add(self, error) :
......@@ -576,6 +580,10 @@ def compute_errors(gold_file, system_file, evaluation, metric) :
return
self.types.append(ErrorType(error.type))
self.types[-1].add(error)
def has(self, error) :
for t in self.types :
if t.type == error.type :
return t.has(error)
def sort(self) :
self.types.sort(key=len, reverse=True)
......@@ -587,7 +595,16 @@ def compute_errors(gold_file, system_file, evaluation, metric) :
return len(self.errors)
def add(self, error) :
self.errors.append(error)
def has(self, error) :
for other_error in self.errors :
if other_error.gold == error.gold :
return True
return False
################################################################################
################################################################################
def compute_errors(gold_file, system_file, evaluation, metric) :
errors = Errors(metric)
for alignment_word in evaluation[metric][1] :
gold = alignment_word.gold_word
......@@ -621,6 +638,7 @@ def main() :
# Evaluate
gold_ud, evaluations = evaluate_wrapper(args)
errors_by_file = []
examples_list = []
for id1 in range(len(evaluations)) :
(system_ud, evaluation) = evaluations[id1]
......@@ -659,24 +677,55 @@ def main() :
errors.sort()
print("Most frequent errors for metric '{}' :".format(errors.metric))
print("{:>12} {:>5} {:>6} {}\n {:->37}".format("ID", "NB", "%AGE", "GOLD->SYSTEM", ""))
print("{:>12} {:5} {:6.2f}%".format("Total", len(errors), 100))
for id3 in range(len(errors.types[:10])) :
error_type = errors.types[:10][id3]
t = error_type.type
nb = len(error_type)
percent = 100.0*nb/len(errors)
id = ":".join(map(str,[id1,id2,id3,"*"]))
print("{:>12} {:5} {:5.2f}% {}".format(id, nb, percent, t))
print("{:>12} {:5} {:6.2f}% {}".format(id, nb, percent, t))
for id4 in range(len(error_type)) :
examples_list.append((":".join(map(str,[id1,id2,id3,id4])), error_type.errors[id4]))
print("")
if len(errors_by_file[0]) > 0 :
for id1 in range(len(evaluations)) :
(system1_ud, evaluation) = evaluations[id1]
for id2 in range(len(evaluations)) :
if id1 == id2 :
continue
(system2_ud, evaluation) = evaluations[id2]
errors1 = errors_by_file[id1]
errors2 = errors_by_file[id2]
if len(errors1) > 0 :
print("{} Error comparison {}".format("*"*31, "*"*31))
print("{:>30} : {}".format("These errors are present in", system1_ud.filename))
print("{:>30} : {}".format("and not in", system2_ud.filename))
for id3 in range(len(errors1)) :
metric = errors1[id3].metric
errors_diff = Errors(metric, errors1[id3], errors2[id3])
errors_diff.sort()
print("{:>12} {:5} {:6.2f}%".format("Total", len(errors_diff), 100))
for id4 in range(len(errors_diff.types[:10])) :
error_type = errors_diff.types[:10][id4]
t = error_type.type
nb = len(error_type)
percent = 100.0*nb/len(errors)
id = ":".join(map(str,["d"+str(id1),id3,id4,"*"]))
print("{:>12} {:5} {:6.2f}% {}".format(id, nb, percent, t))
for id5 in range(len(error_type)) :
examples_list.append((":".join(map(str,["d"+str(id1),id3,id4,id5])), error_type.errors[id5]))
print("")
if len(examples_list) > 0 :
print("{}List of all errors by their ID{}".format("*"*25,"*"*25))
print("{}{:^30}{}\n".format("*"*25,"Format is GOLD | PREDICTED","*"*25))
for i1 in range(len(errors_by_file)) :
for i2 in range(len(errors_by_file[i1])) :
for i3 in range(len(errors_by_file[i1][i2].types)) :
for i4 in range(len(errors_by_file[i1][i2].types[i3].errors)) :
print("ID="+":".join(map(str,[i1,i2,i3,i4])))
print(errors_by_file[i1][i2].types[i3].errors[i4])
for (id,error) in examples_list :
print("ID="+id)
print(error)
print("")
################################################################################
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment