diff --git a/readTrace.py b/readTrace.py index 37510742d26cda3d3a56847421829c8fca64e925..9b998c87c8246b56b36e00729fa9f5635348ac57 100755 --- a/readTrace.py +++ b/readTrace.py @@ -3,6 +3,19 @@ import sys import argparse +backerState = None +################################################################################ +def setBackerState(value) : + global backerState + backerState = value +################################################################################ + +################################################################################ +def getBackerState() : + global backerState + return backerState +################################################################################ + ################################################################################ def lenLine() : return 40 @@ -16,13 +29,17 @@ def englobStr(s, symbol, totalLen) : ################################################################################ ################################################################################ -def alignAsTab(lines) : - withCols = [line.split() for line in lines] - colLens = [max(list(map(len,[withCols[i][j] for i in range(len(withCols))]))) for j in range(len(withCols[0]))] - for line in withCols : - for col in range(len(line)) : - line[col] = line[col] + " "*(colLens[col]-len(line[col])) - return [" ".join(line) for line in withCols] +def isBack(action) : + return "BACK" in action and "NOBACK" not in action +################################################################################ + +################################################################################ +def simple(action) : + if "TAG" in action : + return action.split()[-1] + elif "RIGHT" in action or "LEFT" in action : + return action.split()[0] + return action ################################################################################ ################################################################################ @@ -44,12 +61,7 @@ class Step() : #------------------------------------------------------------------------------- def __str__(self) : - def simple(action) : - if "TAG" in action : - return action.split()[-1] - elif "RIGHT" in action or "LEFT" in action : - return action.split()[0] - return action + action = "'%s'"%simple(self.action) if self.actionCost > self.oracleCost : action += "->" + "'%s'"%simple(self.oracleAction) +\ @@ -62,8 +74,10 @@ class Step() : class Block() : #------------------------------------------------------------------------------- def __init__(self, state) : - self.versions = [[]] # List of list of steps self.state = state + self.versions = [] # List of list of steps + self.stats = [] # For each version, dict of stats + self.newVersion() #------------------------------------------------------------------------------- #------------------------------------------------------------------------------- @@ -74,6 +88,11 @@ class Block() : #------------------------------------------------------------------------------- def newVersion(self) : self.versions.append([]) + self.stats.append({ + "nbErr" : 0, + "avgDist" : 0.0, + "avgIndex" : 0.0, + }) #------------------------------------------------------------------------------- #------------------------------------------------------------------------------- @@ -85,9 +104,13 @@ class Block() : def getAsLines(self, maxNbVersions) : output = [] versions = [] - for version in self.versions : + for v in range(len(self.versions)) : + version = self.versions[v] + stats = self.stats[v] versions.append([]) - lineStr = englobStr("State %d"%self.state, "-", lenLine()) + statsStr = "%derr dist=%.2f index=%.2f"%(stats["nbErr"], + stats["avgDist"], stats["avgIndex"]) + lineStr = englobStr("%s"%(statsStr), "-", lenLine()) versions[-1].append(lineStr + (lenLine()-len(lineStr))*" ") for step in version : versions[-1].append(str(step) + (lenLine()-len(str(step)))*" ") @@ -126,13 +149,13 @@ class History() : block = sentences[-1][1][blockIndex] block.addStep(step) lastState = step.state - if "BACK" in step.action and "NOBACK" not in step.action : + if isBack(step.action) : backSize = int(step.action.split()[-1]) - backState = step.state + setBackerState(step.state) while backSize > 0 : blockIndex -= 1 state = sentences[-1][1][blockIndex].state - if state == backState : + if state == getBackerState() : backSize -= 1 for block in sentences[-1][1][blockIndex:] : block.newVersion() @@ -141,18 +164,71 @@ class History() : #------------------------------------------------------------------------------- #------------------------------------------------------------------------------- - def computeValues(self) : + def computeStats(self) : + globalStats = { + "nbErr" : 0, + "nbErrFound" : 0, + "nbBack" : 0, + "backOnErr" : 0, + "precision" : 0.0, + "recall" : 0.0, + "fScore" : 0.0, + } + for sentence in self.sentences : + for block in sentence[1] : + for i in range(len(block.versions)) : + version = block.versions[i] + stats = block.stats[i] + for step in version : + step.distance = abs(step.actionScore-step.oracleScore) + step.oracleIndex = [a[1] for a in step.scores].index(step.oracleAction) + if step.actionCost > step.oracleCost : + stats["nbErr"] += 1 + stats["avgDist"] += step.distance + stats["avgIndex"] += step.oracleIndex + if i == 0 : + globalStats["nbErr"] += stats["nbErr"] + if len(block.versions) > 1 : + globalStats["nbErrFound"] += stats["nbErr"] + if stats["nbErr"] > 0 : + stats["avgDist"] /= stats["nbErr"] + stats["avgIndex"] /= stats["nbErr"] + for sentence in self.sentences : - for step in sentence[1] : - step.distance = abs(step.actionScore-step.oracleScore) - step.oracleIndex = [a[1] for a in step.scores].index(step.oracleAction) + b = 0 + while b in range(len(sentence[1])) : + block = sentence[1][b] + if block.state != getBackerState() or not isBack(block.versions[0][0].action) : + b += 1 + continue + backSize = int(block.versions[0][0].action.split()[1]) + globalStats["nbBack"] += 1 + backOnErr = False + oldB = b + b -= 1 + while b in range(len(sentence[1])) and backSize > 0 : + if sentence[1][b].stats[0]["nbErr"] > 0 : + backOnErr = True + b -= 1 + if sentence[1][b].state == getBackerState() : + backSize -= 1 + b = oldB + 1 + if backOnErr : + globalStats["backOnErr"] += 1 + if globalStats["nbErr"] > 0 : + globalStats["recall"] = globalStats["nbErrFound"] / globalStats["nbErr"] + if globalStats["nbBack"] > 0 : + globalStats["precision"] = globalStats["backOnErr"] / globalStats["nbBack"] + if globalStats["precision"] + globalStats["recall"] > 0.0 : + globalStats["fScore"] = 2*(globalStats["precision"] * globalStats["recall"])/(globalStats["precision"] + globalStats["recall"]) + return globalStats #------------------------------------------------------------------------------- #------------------------------------------------------------------------------- def printHumanReadable(self, out) : for sentIndex in range(len(self.sentences)) : sentence = self.sentences[sentIndex][1] - annotations = alignAsTab([self.sentences[sentIndex][0][wid] for wid in sorted(list(self.sentences[sentIndex][0].keys()))]) + annotations = [self.sentences[sentIndex][0][wid] for wid in sorted(list(self.sentences[sentIndex][0].keys()))] maxNbVersions = max([block.nbVersions() for block in sentence]) print(englobStr("Sentence %d"%sentIndex, "-", (1+maxNbVersions)*(1+lenLine())), file=out) totalOutput = [] @@ -231,10 +307,12 @@ if __name__ == "__main__" : history = History() history.readFromTrace(args.trace) - history.computeValues() - history.segmentInBlocks() + stats = history.computeStats() + history.printHumanReadable(sys.stdout) + + print(stats) ################################################################################