From 057e1d9d919282e24b65fbf540dd3f46d139943c Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sat, 18 Sep 2021 22:17:39 +0200
Subject: [PATCH] Do not read conllu file for pretrained embeddings

---
 Dicts.py | 39 +++++++++++++++++++++++++++++++++++++++
 Train.py | 11 +++++++++++
 2 files changed, 50 insertions(+)

diff --git a/Dicts.py b/Dicts.py
index be4f3fe..253bf11 100644
--- a/Dicts.py
+++ b/Dicts.py
@@ -15,6 +15,9 @@ class Dicts :
 
   def __init__(self) :
     self.dicts = {}
+    self.foundLower = {}
+    self.notFound = {}
+    self.found = {}
 
   def addDict(self, name, d) :
     if name in self.dicts :
@@ -45,6 +48,8 @@ class Dicts :
 
       splited = line.split('\t')
       for col in targetColumns :
+        if col in pretrained :
+          continue
         if col == "LETTER" :
           values = [letter for letter in splited[col2index["FORM"]]]
         else :
@@ -88,9 +93,19 @@ class Dicts :
     if col not in self.dicts :
       raise Exception("Unknown dict name '%s' among %s"%(col, str(list(self.dicts.keys()))))
     if value in self.dicts[col] :
+      if col not in self.found :
+        self.found[col] = set()
+      self.found[col].add(value)
       return self.dicts[col][value][0]
     if value.lower() in self.dicts[col] :
+      if col not in self.foundLower :
+        self.foundLower[col] = set()
+      self.foundLower[col].add(value)
       return self.dicts[col][value.lower()][0]
+
+    if col not in self.notFound :
+      self.notFound[col] = set()
+    self.notFound[col].add(value)
     return self.dicts[col][self.unkToken][0]
 
   def getElementsOf(self, col) :
@@ -103,5 +118,29 @@ class Dicts :
 
   def load(self, target) :
     self.dicts = json.load(open(target, "r"))
+
+  def printStats(self, output) :
+    for col in self.dicts :
+      total = 0
+      if col in self.found :
+        total += len(self.found[col])
+      if col in self.notFound :
+        total += len(self.notFound[col])
+      if col in self.foundLower :
+        total += len(self.foundLower[col])
+      if total == 0 :
+        continue
+      print(col, file=output)
+      if col in self.found :
+        print("Found : %.2f%%"%(100.0*len(self.found[col])/total), file=output)
+      if col in self.foundLower :
+        print("Found Lower : %.2f%%"%(100.0*len(self.foundLower[col])/total), file=output)
+      if col in self.notFound :
+        print("Not found : %.2f%%"%(100.0*len(self.notFound[col])/total), file=output)
+
+  def resetStats(self) :
+    self.found = {}
+    self.notFound = {}
+    self.foundLower = {}
 ################################################################################
 
diff --git a/Train.py b/Train.py
index 9cc05a7..d83f2a6 100644
--- a/Train.py
+++ b/Train.py
@@ -78,6 +78,12 @@ def extractExamples(debug, transitionSets, strat, config, dicts, network, dynami
 def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted) :
   col2metric = {"HEAD" : "UAS", "DEPREL" : "LAS", "UPOS" : "UPOS", "FEATS" : "UFeats"}
 
+  statsFile = open(modelDir+"/dicts.stats", "w")
+
+  print("Train :", file=statsFile)
+  dicts.printStats(statsFile)
+  dicts.resetStats()
+
   devScore = ""
   saved = True if bestLoss is None else totalLoss < bestLoss
   bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
@@ -91,11 +97,16 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
     saved = True if bestScore is None else score > bestScore
     bestScore = score if bestScore is None else max(bestScore, score)
     devScore = ", Dev : "+" ".join(["%s=%.2f"%(col2metric[toEval[i]], scores[i]) for i in range(len(toEval))])
+    print("\nDev :", file=statsFile)
+    dicts.printStats(statsFile)
+    dicts.resetStats()
   if saved :
     torch.save(model, modelDir+"/network.pt")
   for out in [sys.stderr, open(modelDir+"/train.log", "w" if epoch == 1 else "a")] :
     print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=out)
 
+  statsFile.close()
+
   return bestLoss, bestScore
 ################################################################################
 
-- 
GitLab