From 334d4661a41294705532ab3ed456eb280fc48d05 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sat, 30 May 2020 13:38:52 +0200
Subject: [PATCH] Error comparison for eval script working for upos

---
 scripts/conll18_ud_eval.py | 163 ++++++++++++++++++++++++-------------
 1 file changed, 106 insertions(+), 57 deletions(-)

diff --git a/scripts/conll18_ud_eval.py b/scripts/conll18_ud_eval.py
index 4e36e00..2554a2b 100755
--- a/scripts/conll18_ud_eval.py
+++ b/scripts/conll18_ud_eval.py
@@ -539,55 +539,72 @@ def evaluate_wrapper(args) :
 
 
 ################################################################################
-def compute_errors(gold_file, system_file, evaluation, metric) :
-  class Error :
-    def __init__(self, gold_file, system_file, gold_word, system_word, metric) :
-      self.gold = gold_word
-      self.pred = system_word
-      self.gold_sentence = gold_file.words[gold_file.sentences_words[self.gold.sentence].start:gold_file.sentences_words[self.gold.sentence].end]
-      self.pred_sentence = system_file.words[system_file.sentences_words[self.pred.sentence].start:system_file.sentences_words[self.pred.sentence].end]
-      # TODO : do it for other than UPOS
-      self.type = gold.columns[UPOS]+"->"+pred.columns[UPOS]
-    def __str__(self) :
-      result = []
-      gold_lines = []
-      pred_lines = []
-      for word in self.gold_sentence :
-        gold_lines.append((">" if word == self.gold else " ") + " ".join(filter_columns(word.columns)))
-      for word in self.pred_sentence :
-        pred_lines.append((">" if word == self.pred else " ") + " ".join(filter_columns(word.columns)))
-           
-      for index in range(max(len(gold_lines), len(pred_lines))) :
-        result.append("{} | {}".format(gold_lines[index] if index < len(gold_lines) else "", pred_lines[index] if index < len(pred_lines) else ""))
-      return "\n".join(result)
-
-  class Errors :
-    def __init__(self, metric) :
-      self.types = []
-      self.nb_errors = 0
-      self.metric = metric
-    def __len__(self) : 
-      return self.nb_errors
-    def add(self, error) :
-      self.nb_errors += 1
-      for t in self.types :
-        if t.type == error.type :
-          t.add(error)
-          return
-      self.types.append(ErrorType(error.type))
-      self.types[-1].add(error)
-    def sort(self) :
-      self.types.sort(key=len, reverse=True)
-
-  class ErrorType :
-    def __init__(self, error_type) :
-      self.type = error_type
-      self.errors = []
-    def __len__(self) :
-      return len(self.errors)
-    def add(self, error) :
-      self.errors.append(error)
+class Error :
+  def __init__(self, gold_file, system_file, gold_word, system_word, metric) :
+    self.gold = gold_word
+    self.pred = system_word
+    self.gold_sentence = gold_file.words[gold_file.sentences_words[self.gold.sentence].start:gold_file.sentences_words[self.gold.sentence].end]
+    self.pred_sentence = system_file.words[system_file.sentences_words[self.pred.sentence].start:system_file.sentences_words[self.pred.sentence].end]
+    # TODO : do it for other than UPOS
+    self.type = self.gold.columns[UPOS]+"->"+self.pred.columns[UPOS]
+  def __str__(self) :
+    result = []
+    gold_lines = []
+    pred_lines = []
+    for word in self.gold_sentence :
+      gold_lines.append((">" if word == self.gold else " ") + " ".join(filter_columns(word.columns)))
+    for word in self.pred_sentence :
+      pred_lines.append((">" if word == self.pred else " ") + " ".join(filter_columns(word.columns)))
+         
+    for index in range(max(len(gold_lines), len(pred_lines))) :
+      result.append("{} | {}".format(gold_lines[index] if index < len(gold_lines) else "", pred_lines[index] if index < len(pred_lines) else ""))
+    return "\n".join(result)
+
+class Errors :
+  def __init__(self, metric, errors1=None, errors2=None) :
+    self.types = []
+    self.nb_errors = 0
+    self.metric = metric
+    if errors1 is not None and errors2 is not None :
+      for type in errors1.types :
+        for error in type.errors :
+          if not errors2.has(error) :
+            self.add(error)
+  def __len__(self) : 
+    return self.nb_errors
+  def add(self, error) :
+    self.nb_errors += 1
+    for t in self.types :
+      if t.type == error.type :
+        t.add(error)
+        return
+    self.types.append(ErrorType(error.type))
+    self.types[-1].add(error)
+  def has(self, error) :
+    for t in self.types :
+      if t.type == error.type :
+        return t.has(error)
+  def sort(self) :
+    self.types.sort(key=len, reverse=True)
+
+class ErrorType :
+  def __init__(self, error_type) :
+    self.type = error_type
+    self.errors = []
+  def __len__(self) :
+    return len(self.errors)
+  def add(self, error) :
+    self.errors.append(error)
+  def has(self, error) :
+    for other_error in self.errors :
+      if other_error.gold == error.gold :
+        return True
+    return False
+################################################################################
+
 
+################################################################################
+def compute_errors(gold_file, system_file, evaluation, metric) :
   errors = Errors(metric)
   for alignment_word in evaluation[metric][1] :
     gold = alignment_word.gold_word
@@ -621,6 +638,7 @@ def main() :
   # Evaluate
   gold_ud, evaluations = evaluate_wrapper(args)
   errors_by_file = []
+  examples_list = []
 
   for id1 in range(len(evaluations)) :
     (system_ud, evaluation) = evaluations[id1]
@@ -659,25 +677,56 @@ def main() :
       errors.sort()
       print("Most frequent errors for metric '{}' :".format(errors.metric))
       print("{:>12} {:>5} {:>6} {}\n {:->37}".format("ID", "NB", "%AGE", "GOLD->SYSTEM", ""))
+
+      print("{:>12} {:5} {:6.2f}%".format("Total", len(errors), 100))
       for id3 in range(len(errors.types[:10])) :
         error_type = errors.types[:10][id3]
         t = error_type.type
         nb = len(error_type)
         percent = 100.0*nb/len(errors)
         id = ":".join(map(str,[id1,id2,id3,"*"]))
-        print("{:>12} {:5} {:5.2f}% {}".format(id, nb, percent, t))
+        print("{:>12} {:5} {:6.2f}% {}".format(id, nb, percent, t))
+        for id4 in range(len(error_type)) :
+          examples_list.append((":".join(map(str,[id1,id2,id3,id4])), error_type.errors[id4]))
       print("")
 
-  if len(errors_by_file[0]) > 0 :
+  for id1 in range(len(evaluations)) :
+    (system1_ud, evaluation) = evaluations[id1]
+    for id2 in range(len(evaluations)) :
+      if id1 == id2 :
+        continue
+      (system2_ud, evaluation) = evaluations[id2]
+      errors1 = errors_by_file[id1]
+      errors2 = errors_by_file[id2]
+
+      if len(errors1) > 0 :
+        print("{} Error comparison {}".format("*"*31, "*"*31))
+        print("{:>30} : {}".format("These errors are present in", system1_ud.filename))
+        print("{:>30} : {}".format("and not in", system2_ud.filename))
+      for id3 in range(len(errors1)) :
+        metric = errors1[id3].metric
+        errors_diff = Errors(metric, errors1[id3], errors2[id3])
+        errors_diff.sort()
+        print("{:>12} {:5} {:6.2f}%".format("Total", len(errors_diff), 100))
+        for id4 in range(len(errors_diff.types[:10])) :
+          error_type = errors_diff.types[:10][id4]
+          t = error_type.type
+          nb = len(error_type)
+          percent = 100.0*nb/len(errors)
+          id = ":".join(map(str,["d"+str(id1),id3,id4,"*"]))
+          print("{:>12} {:5} {:6.2f}% {}".format(id, nb, percent, t))
+          for id5 in range(len(error_type)) :
+            examples_list.append((":".join(map(str,["d"+str(id1),id3,id4,id5])), error_type.errors[id5]))
+        print("")
+
+  if len(examples_list) > 0 :
     print("{}List of all errors by their ID{}".format("*"*25,"*"*25))
     print("{}{:^30}{}\n".format("*"*25,"Format is GOLD | PREDICTED","*"*25))
-  for i1 in range(len(errors_by_file)) :
-    for i2 in range(len(errors_by_file[i1])) :
-      for i3 in range(len(errors_by_file[i1][i2].types)) :
-        for i4 in range(len(errors_by_file[i1][i2].types[i3].errors)) :
-          print("ID="+":".join(map(str,[i1,i2,i3,i4])))
-          print(errors_by_file[i1][i2].types[i3].errors[i4])
-          print("")
+
+  for (id,error) in examples_list :
+    print("ID="+id)
+    print(error)
+    print("")
 ################################################################################
 
 
-- 
GitLab