#! /usr/bin/env python3

import sys

# The script expects exactly one command line argument : the trace file to analyse.
if len(sys.argv) != 2 :
  print("USAGE : %s trace.txt"%sys.argv[0], file=sys.stderr)
  # sys.exit is used instead of the bare exit() : the latter is injected by the
  # site module for interactive use and is not guaranteed to exist (e.g. python -S).
  sys.exit(1)

# The only supported size of a back action (the script checks it against the trace)
backSize = 1

# Global counters, filled while the trace is read
nbActionsPerState = {}   # state -> {actionName -> number of occurrences}
nbBacks = {}             # backName -> number of occurrences
nbBacksUndoError = {}    # backName -> number of times it has undone at least one error
nbNoBacks = {}           # NOBACK -> number of occurrences
nbNoBacksUndoError = {}  # NOBACK -> number of times a BACK would have undone at least one error, had BACK always been chosen

# Outcome of actions redone after a BACK, named <previous status><new status>
nbCorrectCorrect = nbErrCorrect = nbErrErr = nbCorrectErr = 0

# Information about the configuration currently being read, unknown until parsed
curState = curStack = curHistory = curHistoryPop = curScores = None
curAction = curCosts = curCost = curNbUndone = curUndone = None

# Read the trace line by line. Every recognised line updates one of the cur*
# variables; the "Oracle costs :" line is treated as the end of an action choice
# and is where all the counters are updated.
# The file is opened in a with statement so that it is closed when the loop ends,
# including when a parsing error is raised.
with open(sys.argv[1], "r") as traceFile :
  for line in traceFile :
    line = line.strip()

    # End of sentence :
    # an empty line seen once a historyPop has been read resets all the current info
    if len(line) == 0 and curHistoryPop is not None :
      curState = None
      curStack = None
      curHistory = None
      curHistoryPop = None
      curScores = None
      curAction = None
      curCosts = None
      curCost = None
      curNbUndone = None
      curUndone = None

    # Collect info on current line :
    if "state :" in line :
      # The state is the integer found after the last ':'
      curState = int(line.split(':')[-1].strip())
    elif "stack :" in line :
      # Keep only the digits of each comma separated element, then drop empty elements
      curStack = ["".join([c for c in a if c.isdigit()]) for a in line.split(':')[-1].strip()[1:-2].split(',')]
      curStack = [int(a) for a in curStack if len(a) > 0]
    elif "historyPop" in line :
      # Tested before "history" because a historyPop line also contains "history".
      # Quotes are removed, the text after the first ':' is cut on ')' and only the
      # pieces containing a comma are kept, each one reduced to what follows its last '('.
      curHistoryPop = ":".join(line.replace("'","").split(':')[1:]).split(')')
      curHistoryPop = [a.split('(')[-1] for a in curHistoryPop if len(a.split(',')) > 1]
      if len(curHistoryPop) > 0 :
        # Each entry becomes (first field, integer found after the last ':' of the fourth field)
        # NOTE(review): the integer is used below as the cost of the action (negative = error) -- confirm against the trace format
        curHistoryPop = [(a.split(',')[0].strip(),int(a.split(',')[3].strip().split(':')[-1])) for a in curHistoryPop]
    elif "nbUndone :" in line :
      curNbUndone = int(line.split(':')[1].strip())
    elif "history" in line :
      # List of the comma separated elements, stripped and without quotes
      curHistory = ["".join([c for c in a.strip() if c != "'"]) for a in line.split(':')[-1].strip()[1:-2].split(',')]
    elif "*" in line :
      # Line of scores : whitespace separated tokens.
      # A token without ':' is glued (with a space) to the token that precedes it,
      # then only the tokens containing ':' are kept as (float, second field) pairs.
      curScores = line.split()
      for i in range(len(curScores))[::-1] :
        if len(curScores[i].split(':')) == 1 :
          curScores[i-1] = " ".join(curScores[i-1:i+1])
      curScores = [a.replace("*","").split(':') for a in curScores if not len(a.split(':')) == 1]
      curScores = [(float(a[0]), a[1]) for a in curScores]
    elif "Chosen action :" in line :
      curAction = line.split(':')[-1].strip()
    elif "Oracle costs :" in line :
      # List of (integer cost, action name) pairs
      curCosts = line.split(':')[-1].strip().split('[')
      curCosts = [a[:-1].replace("'","").replace(']','').split(',') for a in curCosts if ',' in a]
      curCosts = [(int(a[0]), a[1].strip()) for a in curCosts]
      # Cost of the chosen action; no cost is looked up for BACK and NOBACK actions
      curCost = None if "BACK" in curAction else [c[0] for c in curCosts if c[1] == curAction][0]
      # End of action choice :
      # Count actions
      if curState not in nbActionsPerState :
        nbActionsPerState[curState] = {}
      if curAction not in nbActionsPerState[curState] :
        nbActionsPerState[curState][curAction] = 0
      nbActionsPerState[curState][curAction] += 1

      # An action that is neither BACK nor NOBACK, chosen while undone costs are pending :
      # compare the oldest pending undone cost with the cost of the new action
      if curUndone is not None and len(curUndone) > 0 and curNbUndone > 0 and "NOBACK" not in curAction and "BACK" not in curAction :
        prevCost = curUndone[0]
        curUndone = curUndone[1:]
        if prevCost == 0 and curCost == 0 :
          nbCorrectCorrect += 1
        elif prevCost == 0 and curCost != 0 :
          nbCorrectErr += 1
        elif prevCost != 0 and curCost != 0 :
          nbErrErr += 1
        elif prevCost != 0 and curCost == 0 :
          nbErrCorrect += 1

      # NOBACK chosen : only counted when historyPop holds at least backSize NOBACK
      # and the last element of history does not contain BACK
      if "NOBACK" in curAction and len([a for a in curHistoryPop if a[0] == "NOBACK"]) >= backSize and "BACK" not in curHistory[-1] :
        if curAction not in nbNoBacks :
          nbNoBacks[curAction] = 0
        nbNoBacks[curAction] += 1
        size = backSize
        nbErrors = 0
        # Walk historyPop from its last entry backwards, up to the backSize-th NOBACK,
        # counting the entries with a negative value
        for a in curHistoryPop[::-1] :
          if a[0] == "NOBACK" :
            size -= 1
            if size == 0 :
              break
            continue
          if a[1] < 0 :
            nbErrors += 1
        if curAction not in nbNoBacksUndoError :
          nbNoBacksUndoError[curAction] = 0
        if nbErrors > 0 :
          nbNoBacksUndoError[curAction] += 1
      elif "BACK" in curAction and "NOBACK" not in curAction :
        if curAction not in nbBacks :
          nbBacks[curAction] = 0
        nbBacks[curAction] += 1
        # The size of the back is the last word of the action name
        size = int(curAction.split()[-1].strip())
        if size != backSize :
          raise Exception("backSize is wrong")
        nbErrors = 0
        # Same walk as for NOBACK, but the value of every visited entry is also
        # put at the front of curUndone, to be compared later with the redone actions
        for a in curHistoryPop[::-1] :
          if a[0] == "NOBACK" :
            size -= 1
            if size == 0 :
              break
            continue
          if curUndone is None :
            curUndone = []
          curUndone = [a[1]] + curUndone
          if a[1] < 0 :
            nbErrors += 1
        if curAction not in nbBacksUndoError :
          nbBacksUndoError[curAction] = 0
        if nbErrors > 0 :
          nbBacksUndoError[curAction] += 1


# Report, state by state, the total number of actions and the number of
# occurrences of each action, most frequent first
print("Occurrences of actions :")
for state, counts in nbActionsPerState.items() :
  print("State", state, ":")
  print("  %d\ttotal"%sum(counts.values()))
  # (count, name) pairs in decreasing order; names are unique so there are no ties
  ranked = sorted(((nb, name) for name, nb in counts.items()), reverse=True)
  print("\n".join("  %d\t%s"%entry for entry in ranked))

# Answering the question of whether or not the backs are triggered to undo errors
# We compare the number of times a back has undone at least 1 bad action
# with the number of times it would have been the case if we always did back.
print("\nAbout triggering of back actions :")
# A trace may contain no counted NOBACK at all : the counts then default to 0
# instead of raising a KeyError.
noBackTotal = nbNoBacks.get("NOBACK", 0)
noBackUndoErr = nbNoBacksUndoError.get("NOBACK", 0)
for action in nbBacks :
  # Actual behaviour : among the times this back was chosen, how often it undid an error
  total = nbBacks[action]
  undoErr = nbBacksUndoError[action]
  perc = "%5.2f%%"%(100.0*undoErr/total)
  print(action)
  print("  %s (%d/%d)\tof them canceled a bad action"%(perc, undoErr, total))
  # Hypothetical behaviour : the counted NOBACK are added as if a back had been chosen instead
  total += noBackTotal
  undoErr += noBackUndoErr
  perc = "%5.2f%%"%(100.0*undoErr/total)
  print("  %s (%d/%d)\tif it was always chosen"%(perc, undoErr, total))

# For the actions redone after a BACK, report how the status (Correct / Error)
# of the undone action compares with the status of the action that replaced it.
print("\nAbout error correction after a BACK :")
totalRedo = nbErrErr + nbErrCorrect + nbCorrectErr + nbCorrectCorrect
for count, label in [
  (nbErrErr, "Error into Error"),
  (nbCorrectCorrect, "Correct into Correct"),
  (nbCorrectErr, "Correct into Error"),
  (nbErrCorrect, "Error into Correct"),
] :
  # When nothing was redone (totalRedo == 0) print 0.00% instead of dividing by zero
  perc = 100.0*count/totalRedo if totalRedo > 0 else 0.0
  print("  %5.2f%% (%d/%d)\ttransformed %s"%(perc, count, totalRedo, label))