Skip to content
Snippets Groups Projects
Commit d36e1f08 authored by Franck Dary's avatar Franck Dary
Browse files

Improved debug prints

parent 7ff12b50
No related branches found
No related tags found
No related merge requests found
......@@ -75,11 +75,18 @@ class Config :
return len(self.lines)
def printForDebug(self, output) :
printedCols = ["ID","FORM","UPOS","HEAD","DEPREL"]
left = 5
right = 5
print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output)
for lineIndex in range(len(self.lines)) :
print("%s"%("=>" if lineIndex == self.wordIndex else " "), end="", file=output)
toPrint = []
for lineIndex in range(self.wordIndex-left, self.wordIndex+right) :
if lineIndex not in range(len(self.lines)) :
continue
toPrint.append(["%s"%("=>" if lineIndex == self.wordIndex else " ")])
for colIndex in range(len(self.lines[lineIndex])) :
if self.index2col[colIndex] not in printedCols :
continue
value = str(self.getAsFeature(lineIndex, self.index2col[colIndex]))
if value == "" :
value = "_"
......@@ -87,8 +94,13 @@ class Config :
value = self.getAsFeature(int(value), "ID")
elif self.index2col[colIndex] == "HEAD" and value == "-1":
value = "0"
toPrint.append(value)
print("\t".join(toPrint), file=output)
toPrint[-1].append(value)
maxCol = [max([len(toPrint[i][j]) for i in range(len(toPrint))]) for j in range(len(toPrint[0]))]
for i in range(len(toPrint)) :
for j in range(len(toPrint[i])) :
toPrint[i][j] = "{:{}}".format(toPrint[i][j], maxCol[j])
toPrint[i] = toPrint[i][0]+" ".join(toPrint[i][1:])
print("\n".join(toPrint), file=output)
def print(self, output, header=False) :
if header :
......
......@@ -58,14 +58,14 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
while moved :
features = extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
output = torch.nn.functional.softmax(network(features), dim=1)
candidates = sorted([[ts[index].appliable(config), "%.2f"%float(output[0][index]), ts[index].name] for index in range(len(ts))])[::-1]
candidates = [cand[2] for cand in candidates if cand[0]]
scores = sorted([["%.2f"%float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
candidates = [[cand[0],cand[2]] for cand in scores if cand[1]]
if len(candidates) == 0 :
break
candidate = candidates[0]
candidate = candidates[0][1]
if debug :
config.printForDebug(sys.stderr)
print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
print(" ".join(["%s%s:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+("-"*80)+"\n", file=sys.stderr)
moved = applyTransition(ts, strat, config, candidate)
EOS.apply(config)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment