From 7c0fd54898eb26a12502b9c21433c43d0f5a0cf9 Mon Sep 17 00:00:00 2001
From: Alexis Nasr <alexis.nasr@lif.univ-mrs.fr>
Date: Fri, 6 Nov 2020 15:05:11 +0100
Subject: [PATCH] =?UTF-8?q?ajout=20de=20l'option=20v=20dans=20eval=5Fmcf?=
 =?UTF-8?q?=20pour=20le=20calcul=20de=20la=20pr=C3=A9cision=20et=20du=20ra?=
 =?UTF-8?q?ppel=20pour=20chaque=20=C3=A9tiquette?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/eval_mcf.py | 46 ++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 40 insertions(+), 6 deletions(-)

diff --git a/src/eval_mcf.py b/src/eval_mcf.py
index 0cdf9ca..22c2633 100644
--- a/src/eval_mcf.py
+++ b/src/eval_mcf.py
@@ -13,6 +13,12 @@ refMcdFileName = sys.argv[3]
 hypMcdFileName = sys.argv[4]
 lang = sys.argv[5]
 
+if len(sys.argv) == 7:
+    verbose = True
+else:
+    verbose = False
+
+
 #print('reading mcd from file :', refMcdFileName)
 refMcd = Mcd(refMcdFileName)
 
@@ -44,6 +50,10 @@ hypWordBuffer.readAllMcfFile()
 govCorrect = 0
 labelCorrect = 0
 
+hypTotal = {}
+refTotal = {}
+refInterHypTotal = {}
+
 hypSize = hypWordBuffer.getLength()
 for index in range(hypSize):
     refWord = refWordBuffer.getWord(index)
@@ -52,16 +62,40 @@ for index in range(hypSize):
     hypGov = hypWord.getFeat("GOV")
     refLabel = refWord.getFeat("LABEL")
     hypLabel = hypWord.getFeat("LABEL")
+    
+    if hypLabel in hypTotal :
+        hypTotal[hypLabel] += 1
+    else:
+        hypTotal[hypLabel] = 1
+        
+    if refLabel in refTotal :
+        refTotal[refLabel] += 1
+    else:
+        refTotal[refLabel] = 1
     if refGov == hypGov :
         govCorrect += 1
         if refLabel == hypLabel :
             labelCorrect += 1
-
-LAS = labelCorrect / hypSize
-UAS = govCorrect / hypSize
-
-print(lang, LAS, UAS)
-
+            if refLabel in refInterHypTotal :
+                refInterHypTotal[refLabel] += 1
+            else:
+                refInterHypTotal[refLabel] = 1
+
+LAS = 100 * labelCorrect / hypSize
+UAS = 100 * govCorrect / hypSize
+
+print("%s\t%.2f\t%.2f" % (lang, LAS, UAS))
+
+
+if verbose :
+    print("------------------------------")
+    print("label\tprec\trec\tfscore")
+    print("------------------------------")
+    for label in refInterHypTotal:
+        precision = refInterHypTotal[label] / hypTotal[label]
+        recall = refInterHypTotal[label] / refTotal[label]
+        fscore = 2 * precision * recall / (precision + recall)
+        print("%s\t%.2f\t%.2f\t%.2f" % (label, precision, recall, fscore) )
 
 
 #    print("REF GOV = ", refGov, "HYP GOV = ", hypGov, "REF LABEL = ", refLabel, "HYP LABEL = ", hypLabel)
-- 
GitLab