Skip to content
Snippets Groups Projects
Select Git revision
  • 8565bf2d254a55bfdef0fa770c2485b19731b32f
  • master default protected
  • erased
  • states
  • negatives
  • temp
  • negativeExamples
  • Rl
8 results

readTrace.py

Blame
  • readTrace.py 19.42 KiB
    #! /usr/bin/env python3
    
    import sys
    import argparse
    
    # Module-wide "backer" state: History.segmentInBlocks stores here the state of
    # the step on which it last saw a BACK action. None until a BACK is seen.
    backerState = None
    ################################################################################
    def setBackerState(value) :
      """Store `value` in the module-level `backerState` variable."""
      globals()["backerState"] = value
    ################################################################################
    
    ################################################################################
    def getBackerState() :
      """Return the current value of the module-level `backerState` variable."""
      # Reading a module global needs no `global` declaration.
      return backerState
    ################################################################################
    
    ################################################################################
    def lenLine() :
      """Return the fixed width, in characters, used for one displayed cell."""
      cellWidth = 40
      return cellWidth
    ################################################################################
    
    ################################################################################
    def englobStr(s, symbol, totalLen) :
      """Center `s` inside a line of `symbol` characters.

      `s` is first wrapped in one space on each side, then framed with `symbol`
      so that the result is `totalLen` characters long (for a one-character
      `symbol`). When the padding cannot be split evenly, the extra symbol goes
      on the right. If the wrapped text is already `totalLen` or longer, it is
      returned without any framing.
      """
      s = " %s "%s
      df = totalLen - len(s)
      left = df//2
      # The former expression `df-2*df//2+df//2` parsed as `df-(2*df)//2+df//2`,
      # i.e. `df//2`, which lost one symbol whenever `df` was odd.
      right = df - left
      return "%s%s%s"%(symbol*left, s, symbol*right)
    ################################################################################
    
    ################################################################################
    def isBack(action) :
      """Tell whether `action` contains "BACK" without containing "NOBACK"."""
      if "NOBACK" in action :
        return False
      return "BACK" in action
    ################################################################################
    
    ################################################################################
    def isParser(action) :
      """Tell whether `action` contains SHIFT, REDUCE, LEFT or RIGHT."""
      parserKeywords = ("SHIFT", "REDUCE", "LEFT", "RIGHT")
      return any(keyword in action for keyword in parserKeywords)
    ################################################################################
    
    ################################################################################
    def simple(action) :
      """Return a shortened form of `action` for display.

      An action containing "TAG" is reduced to its last whitespace-separated
      token, one containing "RIGHT" or "LEFT" to its first token; any other
      action is returned unchanged.
      """
      if "TAG" in action :
        position = -1
      elif "RIGHT" in action or "LEFT" in action :
        position = 0
      else :
        return action
      return action.split()[position]
    ################################################################################
    
    ################################################################################
    class Step() :
      """One decoding step; its fields are filled in by History.readFromTrace.

      `scores` is a list of (score, action) pairs and `costs` a list of
      (cost, action) pairs. `distance` and `oracleIndex` are computed later by
      History.computeStats.
      """
    #-------------------------------------------------------------------------------
      def __init__(self) :
        self.state = None
        self.action = None
        self.scores = None
        self.costs = None
        self.oracleAction = None
        self.actionScore = None
        self.actionCost = None
        self.oracleScore = None
        # Declared here like every other field: it is read by __str__ and by
        # code outside this class, but used to exist only once a trace line
        # had assigned it.
        self.oracleCost = None
        self.stack = None
        self.historyPop = None
        self.history = None
        self.word = None
    
        self.distance = 0
        self.oracleIndex = 0
    #-------------------------------------------------------------------------------
    
    #-------------------------------------------------------------------------------
      def __str__(self) :
        """Return the first scored actions, plus the oracle action on error.

        Each of the first `args.nbScores` entries of `self.scores` is shown as
        "score@action". When the chosen action costs more than the oracle one,
        " CORR(<oracle action>)" is appended. Relies on the module-level `args`
        set by the script entry point.
        """
        action = " ".join(["%.2f@%s"%(c[0],simple(c[1])) for c in self.scores[:args.nbScores]])
        if self.actionCost > self.oracleCost :
          action = "%s CORR(%s)"%(action, simple(self.oracleAction))
        return action
    #-------------------------------------------------------------------------------
    ################################################################################
    
    ################################################################################
    class Block() :
      """Steps sharing one state, stored as one or more successive versions.

      `versions` holds one list of steps per version and `stats` holds, at the
      same index, the statistics dict of that version. Steps are always added
      to the most recent version.
      """
    #-------------------------------------------------------------------------------
      def __init__(self, state) :
        """Create the block for `state` with a first, empty version."""
        self.state = state
        self.versions = [] # List of list of steps
        self.stats = [] # For each version, dict of stats
        self.newVersion()
    #-------------------------------------------------------------------------------
    
    #-------------------------------------------------------------------------------
      def addStep(self, step) :
        """Append `step` to the latest version."""
        latest = self.versions[-1]
        latest.append(step)
    #-------------------------------------------------------------------------------
    
    #-------------------------------------------------------------------------------
      def newVersion(self) :
        """Open a new empty version together with its zeroed statistics."""
        self.versions.append([])
        self.stats.append(dict(
          nbErr=0,
          avgDist=0.0,
          avgIndex=0.0,
          maxIndex=0.0,
          maxDist=0.0,
        ))
    #-------------------------------------------------------------------------------
    
    #-------------------------------------------------------------------------------
      def nbVersions(self) :
        """Return how many versions this block holds."""
        return len(self.versions)
    #-------------------------------------------------------------------------------
    
    #-------------------------------------------------------------------------------
      def getAsLines(self, maxNbVersions) :
        """Render the block as text lines, one tab-separated column per version.

        Every cell is padded to lenLine() characters. The first row is a header
        built from the word of the version's first step; it is framed with "~"
        instead of "-" when that first step costs more than the oracle action.
        Exactly `maxNbVersions` columns are produced, and cells with no content
        are left blank.
        """
        width = lenLine()
        blank = width*" "
        columns = []
        for steps in self.versions :
          title, frame = "", "-"
          if len(steps) > 0 :
            first = steps[0]
            if first.actionCost > first.oracleCost :
              frame = "~"
            title = first.word
          column = [englobStr(title, frame, width).ljust(width)]
          column.extend(str(step).ljust(width) for step in steps)
          columns.append(column)
    
        rows = []
        for i in range(max(len(column) for column in columns)) :
          cells = []
          for j in range(maxNbVersions) :
            present = j < len(columns) and i < len(columns[j])
            cells.append(columns[j][i] if present else blank)
          rows.append("\t".join(cells))
        return rows
    #-------------------------------------------------------------------------------
    ################################################################################
    
    ################################################################################
    class History() :
      """Decoding history read from one trace file, sentence by sentence.

      After readFromTrace, every entry of `self.sentences` is
      [annotations, steps]: a dict mapping a word id to its annotation line,
      and the list of Step objects read for that sentence. segmentInBlocks
      then replaces the list of steps with a list of Block objects.
      """
    #-------------------------------------------------------------------------------
      def __init__(self) :
        self.sentences = []
    #-------------------------------------------------------------------------------
    
    #-------------------------------------------------------------------------------
      def segmentInBlocks(self) :
        """Group the steps of every sentence into Block objects.

        A new block starts whenever the state of a step differs from the state
        of the previous one. A BACK action rewinds the current block index and
        opens a new version on every block from that index onwards, so that the
        following steps are recorded as a new version of those blocks.
        """
        #structure : sentence = [annotations,list of blocks]
        sentences = []
        for sentenceAnnot in self.sentences :
          annot = sentenceAnnot[0]
          sentence = sentenceAnnot[1]
          lastState = None
          sentences.append([annot, []])
          blockIndex = 0
          for step in sentence :
            if lastState is not None and lastState != step.state :
              blockIndex += 1
            if blockIndex >= len(sentences[-1][1]) :
              sentences[-1][1].append(Block(step.state))
            block = sentences[-1][1][blockIndex]
            block.addStep(step)
            lastState = step.state
            if isBack(step.action) :
              # The last token of the action is the number of rewinds to do.
              backSize = int(step.action.split()[-1])
              setBackerState(step.state)
              # Walk backwards; only blocks in the backer state count as a rewind.
              while backSize > 0 :
                blockIndex -= 1
                state = sentences[-1][1][blockIndex].state
                if state == getBackerState() :
                  backSize -= 1
              for block in sentences[-1][1][blockIndex:] :
                block.newVersion()
    
        self.sentences = sentences
    #-------------------------------------------------------------------------------
    
    #-------------------------------------------------------------------------------
      def computeStats(self) :
        """Compute and return the dict of global statistics of this history.

        Must be called after segmentInBlocks. As a side effect it fills the
        `distance` and `oracleIndex` fields of the steps and the per-version
        `stats` dicts of the blocks.
        """
        globalStats = {
          "nbWords" : 0,
          "nbArcs" : 0,
          "nbMissedArcs" : 0,
          "nbActions" : 0,
          "nbActionsNormal" : 0,
          "nbActionsParser" : 0,
          "nbErr" : 0,
          "nbErrParser" : 0,
          "avgErrCost" : 0,
          "avgErrCostParser" : 0,
          "nbErrFound" : 0,
          "nbBack" : 0,
          "backOnErr" : 0,
          "actionAccuracy" : 0,
          "actionAccuracyParser" : 0,
          "arcsAccuracy" : 0,
          "backPrecision" : 0.0,
          "backRecall" : 0.0,
          "backFScore" : 0.0,
          "nbRedone" : 0,
          "nbRedoneDiminishErr" : 0,
          "nbRedoneAugmentErr" : 0,
          "redoneAvgErrChange" : 0,
          "nbRedoneErrErr" : 0,
          "nbRedoneErrCorrect" : 0,
          "nbRedoneCorrectErr" : 0,
          "nbRedoneCorrectCorrect" : 0,
          "redoneErrErrAvgDistChange" : 0.0,
          "redoneErrErrAvgIndexChange" : 0.0,
        }
        for sentence in self.sentences :
          # sentence[0] is the annotations dict: one entry per word id.
          globalStats["nbWords"] += len(sentence[0])
          globalStats["nbArcs"] += len(sentence[0]) - 1
          for block in sentence[1] :
            for i in range(len(block.versions)) :
              version = block.versions[i]
              stats = block.stats[i]
              # nbActions counts the first version of every block, including
              # the blocks in the backer state that are skipped just below.
              if i == 0 :
                globalStats["nbActions"] += len(version)
              if block.state == getBackerState() :
                continue
              for step in version :
                step.distance = abs(step.actionScore-step.oracleScore)
                # Position of the oracle action in the list of scored actions.
                step.oracleIndex = [a[1] for a in step.scores].index(step.oracleAction)
                if i == 0 :
                  globalStats["avgErrCost"] += step.actionCost
                  if isParser(step.action) :
                    globalStats["avgErrCostParser"] += step.actionCost
                    globalStats["nbMissedArcs"] += step.actionCost
                    globalStats["nbActionsParser"] += 1
                    if step.actionCost > step.oracleCost :
                      globalStats["nbErrParser"] += 1
                if step.actionCost > step.oracleCost :
                  stats["nbErr"] += 1
                  stats["avgDist"] += step.distance
                  stats["avgIndex"] += step.oracleIndex
                  stats["maxDist"] = max(stats["maxDist"], step.distance)
                  stats["maxIndex"] = max(stats["maxIndex"], step.oracleIndex)
              if i == 0 :
                globalStats["nbActionsNormal"] += len(version)
                globalStats["nbErr"] += stats["nbErr"]
                # Errors of a block that has more than one version.
                if len(block.versions) > 1 :
                  globalStats["nbErrFound"] += stats["nbErr"]
              # The second version is compared against the first one.
              if i == 1 :
                prevStats = block.stats[i-1]
                globalStats["nbRedone"] += 1
                distChange = prevStats["maxDist"] - stats["maxDist"]
                indexChange = prevStats["maxIndex"] - stats["maxIndex"]
                if prevStats["nbErr"] > 0 and  stats["nbErr"] > 0 :
                  globalStats["nbRedoneErrErr"] += 1
                  globalStats["redoneErrErrAvgDistChange"] += distChange
                  globalStats["redoneErrErrAvgIndexChange"] += indexChange
                if prevStats["nbErr"] == 0 and  stats["nbErr"] > 0 :
                  globalStats["nbRedoneCorrectErr"] += 1
                if prevStats["nbErr"] == 0 and  stats["nbErr"] == 0 :
                  globalStats["nbRedoneCorrectCorrect"] += 1
                if prevStats["nbErr"] > 0 and  stats["nbErr"] == 0 :
                  globalStats["nbRedoneErrCorrect"] += 1
                if prevStats["nbErr"] > stats["nbErr"] :
                  globalStats["nbRedoneDiminishErr"] += 1
                if prevStats["nbErr"] < stats["nbErr"] :
                  globalStats["nbRedoneAugmentErr"] += 1
                globalStats["redoneAvgErrChange"] += stats["nbErr"] - prevStats["nbErr"]
              # Turn the accumulated sums into averages over the errors.
              if stats["nbErr"] > 0 :
                stats["avgDist"] /= stats["nbErr"]
                stats["avgIndex"] /= stats["nbErr"]
    
        # Second pass: for every BACK, check whether it rewound over an error.
        for sentence in self.sentences :
          b = 0
          while b in range(len(sentence[1])) :
            block = sentence[1][b]
            if block.state != getBackerState() or not isBack(block.versions[0][0].action) :
              b += 1
              continue
            backSize = int(block.versions[0][0].action.split()[1])
            globalStats["nbBack"] += 1
            backOnErr = False
            oldB = b
            b -= 1
            # NOTE(review): when b reaches -1 the state lookup below reads the
            # last block (negative index) just before the loop exits.
            while b in range(len(sentence[1])) and backSize > 0 :
              if sentence[1][b].stats[0]["nbErr"] > 0 :
                backOnErr = True
              b -= 1
              if sentence[1][b].state == getBackerState() :
                backSize -= 1
            b = oldB + 1
            if backOnErr :
              globalStats["backOnErr"] += 1
        if globalStats["nbActionsNormal"] > 0 :
          globalStats["actionAccuracy"] = 100.0*(globalStats["nbActionsNormal"]-globalStats["nbErr"])/globalStats["nbActionsNormal"]
        if globalStats["nbActionsParser"] > 0 :
          globalStats["actionAccuracyParser"] = 100.0*(globalStats["nbActionsParser"]-globalStats["nbErrParser"])/globalStats["nbActionsParser"]
        if globalStats["nbArcs"] > 0 :
          globalStats["arcsAccuracy"] = 100.0*(globalStats["nbArcs"]-globalStats["nbMissedArcs"])/globalStats["nbArcs"]
        if globalStats["nbErr"] > 0 :
          globalStats["avgErrCost"] /= globalStats["nbErr"]
        if globalStats["nbErrParser"] > 0 :
          globalStats["avgErrCostParser"] /= globalStats["nbErrParser"]
        if globalStats["nbErr"] > 0 :
          globalStats["backRecall"] = 100.0*globalStats["nbErrFound"] / globalStats["nbErr"]
        if globalStats["nbBack"] > 0 :
          globalStats["backPrecision"] = 100.0*globalStats["backOnErr"] / globalStats["nbBack"]
        if globalStats["backPrecision"] + globalStats["backRecall"] > 0.0 :
          globalStats["backFScore"] = 2*(globalStats["backPrecision"] * globalStats["backRecall"])/(globalStats["backPrecision"] + globalStats["backRecall"])
        if globalStats["nbRedoneErrErr"] :
          globalStats["redoneErrErrAvgDistChange"] /= globalStats["nbRedoneErrErr"]
          globalStats["redoneErrErrAvgIndexChange"] /= globalStats["nbRedoneErrErr"]
        if globalStats["nbRedone"] :
          globalStats["redoneAvgErrChange"] /= globalStats["nbRedone"]
          # Dividing by nbRedone*(1/100) turns each count into a percentage.
          globalStats["nbRedoneDiminishErr"] /= globalStats["nbRedone"] * (1/100)
          globalStats["nbRedoneAugmentErr"] /= globalStats["nbRedone"] * (1/100)
          globalStats["nbRedoneCorrectCorrect"] /= globalStats["nbRedone"] * (1/100)
          globalStats["nbRedoneErrErr"] /= globalStats["nbRedone"] * (1/100)
          globalStats["nbRedoneCorrectErr"] /= globalStats["nbRedone"] * (1/100)
          globalStats["nbRedoneErrCorrect"] /= globalStats["nbRedone"] * (1/100)
    
        return globalStats
    #-------------------------------------------------------------------------------
    
    #-------------------------------------------------------------------------------
      def printHumanReadable(self, out) :
        """Write every sentence to `out`: its blocks and, beside them, its annotations.

        Must be called after segmentInBlocks. The annotation lines are printed
        in increasing word-id order, to the right of the first output lines.
        """
        for sentIndex in range(len(self.sentences)) :
          sentence = self.sentences[sentIndex][1]
          annotations = [self.sentences[sentIndex][0][wid] for wid in sorted(list(self.sentences[sentIndex][0].keys()))]
          maxNbVersions = max([block.nbVersions() for block in sentence])
          print(englobStr("Sentence %d"%sentIndex, "-", (1+maxNbVersions)*(1+lenLine())), file=out)
          totalOutput = []
          for block in sentence :
            totalOutput += block.getAsLines(maxNbVersions)
          for i in range(len(totalOutput)) :
            # Line 0 carries the title, lines 1..len(annotations) one annotation each.
            print(totalOutput[i] + ("\t"+("Output of the machine:" if i == 0 else annotations[i-1]) if i in range(len(annotations)+1) else ""), file=out)
          print("", file=out)
    #-------------------------------------------------------------------------------
    
    #-------------------------------------------------------------------------------
      def readFromTrace(self, traceFile) :
        """Parse the trace file `traceFile` and fill `self.sentences`.

        An empty line opens a new sentence. Every line located before the first
        line containing "-----" is otherwise ignored. A step is completed, and
        appended to the current sentence, when its "Oracle costs :" line is
        read. Relies on the module-level `args` for the index of the word form.
        """
        curStep = Step()
        started = False
    
        # Context manager: the trace file is closed even if parsing raises.
        with open(traceFile, "r") as trace :
          for line in trace :
            line = line.rstrip()
    
            # End of sentence :
            if len(line) == 0 :
              if len(self.sentences) == 0 or len(self.sentences[-1]) > 0 :
                self.sentences.append([])
                self.sentences[-1].append({})
                self.sentences[-1].append([])
              continue
    
            if "-----" in line :
              started = True
            if not started :
              continue
    
            # The order of the tests below matters: the first matching marker wins.
            if "state :" in line :
              curStep.state = int(line.split(':')[-1].strip())
            elif "=>" in line :
              # Annotation line of the current word: first token is its id.
              annotLine = line.split("=>")[-1]
              curId = int(annotLine.split()[0])
              curStep.word = annotLine.split()[args.formIndex]
              self.sentences[-1][0][curId] = annotLine
            elif "stack :" in line :
              # Keep only the digits of every comma-separated element.
              curStep.stack = ["".join([c for c in a if c.isdigit()]) for a in line.split(':')[-1].strip()[1:-2].split(',')]
              curStep.stack = [int(a) for a in curStep.stack if len(a) > 0]
            elif "historyPop" in line :
              curStep.historyPop = ":".join(line.replace("'","").split(':')[1:]).split(')')
              curStep.historyPop = [a.split('(')[-1] for a in curStep.historyPop if len(a.split(',')) > 1]
              if len(curStep.historyPop) > 0 :
                curStep.historyPop = [(a.split(',')[0].strip(),int(a.split(',')[3].strip().split(':')[-1])) for a in curStep.historyPop]
            elif "history" in line :
              curStep.history = ["".join([c for c in a.strip() if c != "'"]) for a in line.split(':')[-1].strip()[1:-2].split(',')]
            elif "*" in line :
              # Scores line made of "score:action" tokens. A token without ':'
              # is glued back to the token before it.
              curStep.scores = line.split()
              for i in range(len(curStep.scores))[::-1] :
                if len(curStep.scores[i].split(':')) == 1 :
                  curStep.scores[i-1] = " ".join(curStep.scores[i-1:i+1])
              curStep.scores = [a.replace("*","").split(':') for a in curStep.scores if not len(a.split(':')) == 1]
              curStep.scores = [(float(a[0]), a[1]) for a in curStep.scores]
            elif "  " in line :
              # Other annotation line; skipped when its first token contains '-'.
              annotLine = " ".join(line.split("  ")[1:])
              if "-" not in annotLine.split()[0] :
                curId = int(annotLine.split()[0])
                self.sentences[-1][0][curId] = annotLine
            elif "Chosen action :" in line :
              curStep.action = line.split(':')[-1].strip()
            elif "Oracle costs :" in line :
              curStep.costs = line.split(':')[-1].strip().split('[')
              curStep.costs = [a[:-1].replace("'","").replace(']','').split(',') for a in curStep.costs if ',' in a]
              curStep.costs = [(int(a[0]), a[1].strip()) for a in curStep.costs]
              # Any action containing "BACK" is given a cost of 0.
              curStep.actionCost = 0 if "BACK" in curStep.action else [c[0] for c in curStep.costs if c[1] == curStep.action][0]
              curStep.oracleCost = min([b[0] for b in curStep.costs])
              # The oracle action is the first action with the minimal cost.
              curStep.oracleAction = [a[1] for a in curStep.costs if a[0] == curStep.oracleCost][0]
              curStep.oracleScore = [a[0] for a in curStep.scores if a[1] == curStep.oracleAction][0]
              curStep.actionScore = [a[0] for a in curStep.scores if a[1] == curStep.action][0]
    
              # This line ends the step: store it and start a fresh one.
              self.sentences[-1][-1].append(curStep)
              curStep = Step()
    #-------------------------------------------------------------------------------
    ################################################################################
    
    ################################################################################
    def prettyNumber(num) :
      """Format `num` with at most two decimals, without trailing zeros.

      A value whose two decimals are both zero is printed without a decimal
      point at all.
      """
      parts = ("%.2f"%num).split('.')
      decimals = parts[1].rstrip('.0')
      if not decimals :
        return parts[0]
      return parts[0] + "." + decimals
    ################################################################################
    
    ################################################################################
    if __name__ == "__main__" :
      # Script entry point: read every trace, then print the steps and/or the stats.
      parser = argparse.ArgumentParser()
      parser.add_argument("traces", nargs="+", default=[],
        help="File produced by debug mode (-d) of the decoding.")
      parser.add_argument("--steps", default=False, action="store_true",
        help="Print all decoding steps.")
      parser.add_argument("--stats", default=False, action="store_true",
        help="Print global stats about the decoding.")
      # type=int: both values are used as a list index / slice bound, so a
      # value given on the command line must not stay a string.
      parser.add_argument("--formIndex", default=1, type=int,
        help="Index of the form of words in the trace file.")
      parser.add_argument("--nbScores", default=2, type=int,
        help="Number of action scores displayed in --steps mode.")
      # `args` has to stay a module-level name: Step.__str__ and
      # History.readFromTrace read it as a global.
      args = parser.parse_args()
    
      if not (args.steps or args.stats) :
        print("ERROR: must provide --steps or --stats", file=sys.stderr)
        # sys.exit rather than exit: the latter is only injected by `site`.
        sys.exit(1)
    
      histories = []
      stats = []
      for trace in args.traces :
        history = History()
        history.readFromTrace(trace)
        history.segmentInBlocks()
        histories.append(history)
        stats.append(history.computeStats())
    
      if args.steps :
        for trace, history in zip(args.traces, histories) :
          print("History of '%s' :\n"%trace)
          history.printHumanReadable(sys.stdout)
    
      if args.stats :
        # One column for the stat names, then one column per trace file.
        columns = [["Filename"]+list(stats[0].keys())]
        for trace, traceStats in zip(args.traces, stats) :
          columns.append([trace]+[prettyNumber(value) for value in traceStats.values()])
        widths = [max(map(len, column)) for column in columns]
        for row in zip(*columns) :
          for j, cell in enumerate(row) :
            # The name column is filled with dots, the value columns with spaces.
            sep = "." if j == 0 else " "
            print("%s"%(cell+sep*(1+widths[j]-len(cell))), end=" ")
          print("")
    ################################################################################