From d36e1f08b59a8fb6572d9cb731ed3302385ac62f Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sat, 24 Apr 2021 16:20:25 +0200
Subject: [PATCH] Improved debug prints

---
 Config.py | 22 +++++++++++++++++-----
 Decode.py |  8 ++++----
 2 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/Config.py b/Config.py
index b3f3e6c..9e208f9 100644
--- a/Config.py
+++ b/Config.py
@@ -75,11 +75,18 @@ class Config :
     return len(self.lines)
 
   def printForDebug(self, output) :
+    printedCols = ["ID","FORM","UPOS","HEAD","DEPREL"]
+    left = 5
+    right = 5
     print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output)
-    for lineIndex in range(len(self.lines)) :
-      print("%s"%("=>" if lineIndex == self.wordIndex else "  "), end="", file=output)
-      toPrint = []
+    toPrint = []
+    for lineIndex in range(self.wordIndex-left, self.wordIndex+right) :
+      if lineIndex not in range(len(self.lines)) :
+        continue
+      toPrint.append(["%s"%("=>" if lineIndex == self.wordIndex else "  ")])
       for colIndex in range(len(self.lines[lineIndex])) :
+        if self.index2col[colIndex] not in printedCols :
+          continue
         value = str(self.getAsFeature(lineIndex, self.index2col[colIndex]))
         if value == "" :
           value = "_"
@@ -87,8 +94,13 @@ class Config :
           value = self.getAsFeature(int(value), "ID")
         elif self.index2col[colIndex] == "HEAD" and value == "-1":
           value = "0"
-        toPrint.append(value)
-      print("\t".join(toPrint), file=output)
+        toPrint[-1].append(value)
+    maxCol = [max([len(toPrint[i][j]) for i in range(len(toPrint))]) for j in range(len(toPrint[0]))]
+    for i in range(len(toPrint)) :
+      for j in range(len(toPrint[i])) :
+        toPrint[i][j] = "{:{}}".format(toPrint[i][j], maxCol[j])
+      toPrint[i] = toPrint[i][0]+" ".join(toPrint[i][1:])
+    print("\n".join(toPrint), file=output)
 
   def print(self, output, header=False) :
     if header :
diff --git a/Decode.py b/Decode.py
index 300572a..1f48a29 100644
--- a/Decode.py
+++ b/Decode.py
@@ -58,14 +58,14 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
     while moved :
       features = extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
       output = torch.nn.functional.softmax(network(features), dim=1)
-      candidates = sorted([[ts[index].appliable(config), "%.2f"%float(output[0][index]), ts[index].name] for index in range(len(ts))])[::-1]
-      candidates = [cand[2] for cand in candidates if cand[0]]
+      scores = sorted([["%.2f"%float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
+      candidates = [[cand[0],cand[2]] for cand in scores if cand[1]]
       if len(candidates) == 0 :
         break
-      candidate = candidates[0]
+      candidate = candidates[0][1]
       if debug :
         config.printForDebug(sys.stderr)
-        print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
+        print(" ".join(["%s%s:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+("-"*80)+"\n", file=sys.stderr)
       moved = applyTransition(ts, strat, config, candidate)
 
   EOS.apply(config)
-- 
GitLab