From 97c328f228de9347f7c1b728e2f0ec426419f81e Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 29 May 2020 17:25:23 +0200
Subject: [PATCH] Prining list of errors in eval script

---
 scripts/conll18_ud_eval.py | 53 ++++++++++++++++++++++++++++++++++----
 1 file changed, 48 insertions(+), 5 deletions(-)

diff --git a/scripts/conll18_ud_eval.py b/scripts/conll18_ud_eval.py
index 903f57d..4e36e00 100755
--- a/scripts/conll18_ud_eval.py
+++ b/scripts/conll18_ud_eval.py
@@ -122,6 +122,19 @@ UNIVERSAL_FEATURES = {
 }
 
 
+################################################################################
+def filter_columns(columns) :
+  res = []
+  indexes = [0, 1, 3, 6, 7]
+  lengths = [4, 8, 8, 4, 8]
+
+  for (content, max_len) in [(columns[indexes[index]], lengths[index]) for index in range(len(indexes))] :
+    res.append(("{:"+str(max_len)+"}").format(content if len(content) <= max_len else "{}…{}".format(content[0:math.ceil((max_len-1)/2)],content[-((max_len-1)//2):])))
+
+  return res
+################################################################################
+
+
 ################################################################################
 # UD Error is used when raising exceptions in this module
 class UDError(Exception) :
@@ -535,6 +548,18 @@ def compute_errors(gold_file, system_file, evaluation, metric) :
       self.pred_sentence = system_file.words[system_file.sentences_words[self.pred.sentence].start:system_file.sentences_words[self.pred.sentence].end]
       # TODO : do it for other than UPOS
       self.type = gold.columns[UPOS]+"->"+pred.columns[UPOS]
+    def __str__(self) :
+      result = []
+      gold_lines = []
+      pred_lines = []
+      for word in self.gold_sentence :
+        gold_lines.append((">" if word == self.gold else " ") + " ".join(filter_columns(word.columns)))
+      for word in self.pred_sentence :
+        pred_lines.append((">" if word == self.pred else " ") + " ".join(filter_columns(word.columns)))
+           
+      for index in range(max(len(gold_lines), len(pred_lines))) :
+        result.append("{} | {}".format(gold_lines[index] if index < len(gold_lines) else "", pred_lines[index] if index < len(pred_lines) else ""))
+      return "\n".join(result)
 
   class Errors :
     def __init__(self, metric) :
@@ -595,12 +620,15 @@ def main() :
 
   # Evaluate
   gold_ud, evaluations = evaluate_wrapper(args)
+  errors_by_file = []
 
-  for (system_ud, evaluation) in evaluations :
+  for id1 in range(len(evaluations)) :
+    (system_ud, evaluation) = evaluations[id1]
     fnamelen = len(system_ud.filename)
     print("*"*math.ceil((80-2-fnamelen)/2),system_ud.filename,"*"*math.floor((80-2-fnamelen)/2))
     # Compute errors
     errors_list = [compute_errors(gold_ud, system_ud, evaluation, metric) for metric in errors_metrics]
+    errors_by_file.append(errors_list)
   
     # Print the evaluation
     if args.counts :
@@ -626,15 +654,30 @@ def main() :
           "{:10.2f}".format(100 * evaluation[metric][0].aligned_accuracy) if evaluation[metric][0].aligned_accuracy is not None else ""
         ))
   
-    for errors in errors_list :
+    for id2 in range(len(errors_list)) :
+      errors = errors_list[id2]
       errors.sort()
-      print("")
       print("Most frequent errors for metric '{}' :".format(errors.metric))
-      for error_type in errors.types[:10] :
+      print("{:>12} {:>5} {:>6} {}\n {:->37}".format("ID", "NB", "%AGE", "GOLD->SYSTEM", ""))
+      for id3 in range(len(errors.types[:10])) :
+        error_type = errors.types[:10][id3]
         t = error_type.type
         nb = len(error_type)
         percent = 100.0*nb/len(errors)
-        print("{:5} {:5.2f}% {}".format(nb, percent, t))
+        id = ":".join(map(str,[id1,id2,id3,"*"]))
+        print("{:>12} {:5} {:5.2f}% {}".format(id, nb, percent, t))
+      print("")
+
+  if len(errors_by_file[0]) > 0 :
+    print("{}List of all errors by their ID{}".format("*"*25,"*"*25))
+    print("{}{:^30}{}\n".format("*"*25,"Format is GOLD | PREDICTED","*"*25))
+  for i1 in range(len(errors_by_file)) :
+    for i2 in range(len(errors_by_file[i1])) :
+      for i3 in range(len(errors_by_file[i1][i2].types)) :
+        for i4 in range(len(errors_by_file[i1][i2].types[i3].errors)) :
+          print("ID="+":".join(map(str,[i1,i2,i3,i4])))
+          print(errors_by_file[i1][i2].types[i3].errors[i4])
+          print("")
 ################################################################################
 
 
-- 
GitLab