From d36e1f08b59a8fb6572d9cb731ed3302385ac62f Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sat, 24 Apr 2021 16:20:25 +0200 Subject: [PATCH] Improved debug prints --- Config.py | 22 +++++++++++++++++----- Decode.py | 8 ++++---- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/Config.py b/Config.py index b3f3e6c..9e208f9 100644 --- a/Config.py +++ b/Config.py @@ -75,11 +75,18 @@ class Config : return len(self.lines) def printForDebug(self, output) : + printedCols = ["ID","FORM","UPOS","HEAD","DEPREL"] + left = 5 + right = 5 print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output) - for lineIndex in range(len(self.lines)) : - print("%s"%("=>" if lineIndex == self.wordIndex else " "), end="", file=output) - toPrint = [] + toPrint = [] + for lineIndex in range(self.wordIndex-left, self.wordIndex+right) : + if lineIndex not in range(len(self.lines)) : + continue + toPrint.append(["%s"%("=>" if lineIndex == self.wordIndex else " ")]) for colIndex in range(len(self.lines[lineIndex])) : + if self.index2col[colIndex] not in printedCols : + continue value = str(self.getAsFeature(lineIndex, self.index2col[colIndex])) if value == "" : value = "_" @@ -87,8 +94,13 @@ class Config : value = self.getAsFeature(int(value), "ID") elif self.index2col[colIndex] == "HEAD" and value == "-1": value = "0" - toPrint.append(value) - print("\t".join(toPrint), file=output) + toPrint[-1].append(value) + maxCol = [max([len(toPrint[i][j]) for i in range(len(toPrint))]) for j in range(len(toPrint[0]))] + for i in range(len(toPrint)) : + for j in range(len(toPrint[i])) : + toPrint[i][j] = "{:{}}".format(toPrint[i][j], maxCol[j]) + toPrint[i] = toPrint[i][0]+" ".join(toPrint[i][1:]) + print("\n".join(toPrint), file=output) def print(self, output, header=False) : if header : diff --git a/Decode.py b/Decode.py index 300572a..1f48a29 100644 --- a/Decode.py +++ b/Decode.py @@ -58,14 +58,14 @@ def decodeModel(ts, strat, config, network, dicts, debug) : while moved : features = extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice) output = torch.nn.functional.softmax(network(features), dim=1) - candidates = sorted([[ts[index].appliable(config), "%.2f"%float(output[0][index]), ts[index].name] for index in range(len(ts))])[::-1] - candidates = [cand[2] for cand in candidates if cand[0]] + scores = sorted([["%.2f"%float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1] + candidates = [[cand[0],cand[2]] for cand in scores if cand[1]] if len(candidates) == 0 : break - candidate = candidates[0] + candidate = candidates[0][1] if debug : config.printForDebug(sys.stderr) - print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr) + print(" ".join(["%s%s:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+("-"*80)+"\n", file=sys.stderr) moved = applyTransition(ts, strat, config, candidate) EOS.apply(config) -- GitLab