Skip to content
Snippets Groups Projects
Commit 48e9375b authored by Franck Dary's avatar Franck Dary
Browse files

Improved readability in mode --steps of script readTrace.py

parent 6a6cd177
No related branches found
No related tags found
No related merge requests found
...@@ -62,6 +62,7 @@ class Step() : ...@@ -62,6 +62,7 @@ class Step() :
self.stack = None self.stack = None
self.historyPop = None self.historyPop = None
self.history = None self.history = None
self.word = None
self.distance = 0 self.distance = 0
self.oracleIndex = 0 self.oracleIndex = 0
...@@ -69,11 +70,9 @@ class Step() : ...@@ -69,11 +70,9 @@ class Step() :
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
def __str__(self) : def __str__(self) :
action = " ".join(["%.2f@%s"%(c[0],simple(c[1])) for c in self.scores[:args.nbScores]])
action = "'%s'"%simple(self.action)
if self.actionCost > self.oracleCost : if self.actionCost > self.oracleCost :
action += "->" + "'%s'"%simple(self.oracleAction) +\ action = "%s CORR(%s)"%(action, simple(self.oracleAction))
"(dist=%.2f index=%d)"%(self.distance, self.oracleIndex)
return action return action
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
################################################################################ ################################################################################
...@@ -118,9 +117,10 @@ class Block() : ...@@ -118,9 +117,10 @@ class Block() :
version = self.versions[v] version = self.versions[v]
stats = self.stats[v] stats = self.stats[v]
versions.append([]) versions.append([])
statsStr = "%derr dist=%.2f index=%.2f"%(stats["nbErr"], englobChar = "-"
stats["avgDist"], stats["avgIndex"]) if version[0].actionCost > version[0].oracleCost :
lineStr = englobStr("%s"%(statsStr), "-", lenLine()) englobChar = "~"
lineStr = englobStr(version[0].word, englobChar, lenLine())
versions[-1].append(lineStr + (lenLine()-len(lineStr))*" ") versions[-1].append(lineStr + (lenLine()-len(lineStr))*" ")
for step in version : for step in version :
versions[-1].append(str(step) + (lenLine()-len(str(step)))*" ") versions[-1].append(str(step) + (lenLine()-len(str(step)))*" ")
...@@ -327,7 +327,7 @@ class History() : ...@@ -327,7 +327,7 @@ class History() :
for block in sentence : for block in sentence :
totalOutput += block.getAsLines(maxNbVersions) totalOutput += block.getAsLines(maxNbVersions)
for i in range(len(totalOutput)) : for i in range(len(totalOutput)) :
print(totalOutput[i] + ("\t"+annotations[i] if i in range(len(annotations)) else ""), file=out) print(totalOutput[i] + ("\t"+("Output of the machine:" if i == 0 else annotations[i-1]) if i in range(len(annotations)+1) else ""), file=out)
print("", file=out) print("", file=out)
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
...@@ -357,6 +357,7 @@ class History() : ...@@ -357,6 +357,7 @@ class History() :
elif "=>" in line : elif "=>" in line :
annotLine = line.split("=>")[-1] annotLine = line.split("=>")[-1]
curId = int(annotLine.split()[0]) curId = int(annotLine.split()[0])
curStep.word = annotLine.split()[args.formIndex]
self.sentences[-1][0][curId] = annotLine self.sentences[-1][0][curId] = annotLine
elif "stack :" in line : elif "stack :" in line :
curStep.stack = ["".join([c for c in a if c.isdigit()]) for a in line.split(':')[-1].strip()[1:-2].split(',')] curStep.stack = ["".join([c for c in a if c.isdigit()]) for a in line.split(':')[-1].strip()[1:-2].split(',')]
...@@ -416,8 +417,16 @@ if __name__ == "__main__" : ...@@ -416,8 +417,16 @@ if __name__ == "__main__" :
help="Print all decoding steps.") help="Print all decoding steps.")
parser.add_argument("--stats", default=False, action="store_true", parser.add_argument("--stats", default=False, action="store_true",
help="Print global stats about the decoding.") help="Print global stats about the decoding.")
parser.add_argument("--formIndex", default=1,
help="Index of the form of words in the trace file.")
parser.add_argument("--nbScores", default=2,
help="Number of action scores displayed in --steps mode.")
args = parser.parse_args() args = parser.parse_args()
if not (args.steps or args.stats) :
print("ERROR: must provide --steps or --stats", file=sys.stderr)
exit(1)
histories = [] histories = []
stats = [] stats = []
for trace in args.traces : for trace in args.traces :
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment