Skip to content
Snippets Groups Projects
Commit d36e1f08 authored by Franck Dary's avatar Franck Dary
Browse files

Improved debug prints

parent 7ff12b50
No related branches found
No related tags found
No related merge requests found
...@@ -75,11 +75,18 @@ class Config : ...@@ -75,11 +75,18 @@ class Config :
return len(self.lines) return len(self.lines)
def printForDebug(self, output) : def printForDebug(self, output) :
printedCols = ["ID","FORM","UPOS","HEAD","DEPREL"]
left = 5
right = 5
print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output) print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output)
for lineIndex in range(len(self.lines)) :
print("%s"%("=>" if lineIndex == self.wordIndex else " "), end="", file=output)
toPrint = [] toPrint = []
for lineIndex in range(self.wordIndex-left, self.wordIndex+right) :
if lineIndex not in range(len(self.lines)) :
continue
toPrint.append(["%s"%("=>" if lineIndex == self.wordIndex else " ")])
for colIndex in range(len(self.lines[lineIndex])) : for colIndex in range(len(self.lines[lineIndex])) :
if self.index2col[colIndex] not in printedCols :
continue
value = str(self.getAsFeature(lineIndex, self.index2col[colIndex])) value = str(self.getAsFeature(lineIndex, self.index2col[colIndex]))
if value == "" : if value == "" :
value = "_" value = "_"
...@@ -87,8 +94,13 @@ class Config : ...@@ -87,8 +94,13 @@ class Config :
value = self.getAsFeature(int(value), "ID") value = self.getAsFeature(int(value), "ID")
elif self.index2col[colIndex] == "HEAD" and value == "-1": elif self.index2col[colIndex] == "HEAD" and value == "-1":
value = "0" value = "0"
toPrint.append(value) toPrint[-1].append(value)
print("\t".join(toPrint), file=output) maxCol = [max([len(toPrint[i][j]) for i in range(len(toPrint))]) for j in range(len(toPrint[0]))]
for i in range(len(toPrint)) :
for j in range(len(toPrint[i])) :
toPrint[i][j] = "{:{}}".format(toPrint[i][j], maxCol[j])
toPrint[i] = toPrint[i][0]+" ".join(toPrint[i][1:])
print("\n".join(toPrint), file=output)
def print(self, output, header=False) : def print(self, output, header=False) :
if header : if header :
......
...@@ -58,14 +58,14 @@ def decodeModel(ts, strat, config, network, dicts, debug) : ...@@ -58,14 +58,14 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
while moved : while moved :
features = extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice) features = extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
output = torch.nn.functional.softmax(network(features), dim=1) output = torch.nn.functional.softmax(network(features), dim=1)
candidates = sorted([[ts[index].appliable(config), "%.2f"%float(output[0][index]), ts[index].name] for index in range(len(ts))])[::-1] scores = sorted([["%.2f"%float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
candidates = [cand[2] for cand in candidates if cand[0]] candidates = [[cand[0],cand[2]] for cand in scores if cand[1]]
if len(candidates) == 0 : if len(candidates) == 0 :
break break
candidate = candidates[0] candidate = candidates[0][1]
if debug : if debug :
config.printForDebug(sys.stderr) config.printForDebug(sys.stderr)
print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr) print(" ".join(["%s%s:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+("-"*80)+"\n", file=sys.stderr)
moved = applyTransition(ts, strat, config, candidate) moved = applyTransition(ts, strat, config, candidate)
EOS.apply(config) EOS.apply(config)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment