import sys
import random
import torch
import copy
import math

from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp, prettyInt, numParameters, getDevice
from Rl import ReplayMemory, selectAction, optimizeModel, rewarding
import Networks
import Decode
import Config

from conll18_ud_eval import load_conllu, evaluate

################################################################################
def trainMode(debug, networkName, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
  """Read the training corpus and dispatch to the requested training procedure.

  `type` selects the procedure: "oracle" runs trainModelOracle, "rl" runs
  trainModelRl. Any other value prints an error on stderr and exits the
  process with status 1 (checked before the corpus is read).
  `transitionSet` is forwarded as the list of transition sets of both trainers.
  """
  if type not in ("oracle", "rl") :
    print("ERROR : unknown type '%s'"%type, file=sys.stderr)
    # sys.exit is always available, unlike the site-provided exit() helper (missing under `python -S`).
    sys.exit(1)

  sentences = Config.readConllu(filename, predicted)

  if type == "oracle" :
    trainModelOracle(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent)
  else :
    trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent)
################################################################################

################################################################################
# Return list of examples for each transitionSet
def extractExamples(debug, transitionSets, strat, config, dicts, network, dynamic) :
  """Parse `config` from its start and collect one training example per step.

  At every step, the appliable transitions of the current state's set (BACK
  excluded) are scored by the oracle and one of the lowest-scored ones is
  drawn at random as the gold label. That gold transition is applied, unless
  `dynamic` is true, in which case the appliable transition the network
  scores highest is applied instead.

  Each example is the gold transition index followed by the extracted
  features, concatenated into one LongTensor. Returns one list of examples
  per transition set, indexed by the state the example was produced in.
  The loop stops when no candidate remains or applyTransition reports no
  move; an EOS transition is then applied to `config`.
  """
  perState = [[] for _ in transitionSets]
  with torch.no_grad() :
    endOfSentence = Transition("EOS")
    config.moveWordIndex(0)
    config.state = 0
    while True :
      transitions = transitionSets[config.state]
      missing = getMissingLinks(config)
      ranked = sorted([[t.getOracleScore(config, missing), t] for t in transitions if t.appliable(config) and t.name != "BACK"])
      if not ranked :
        break
      lowest = min(score for score, _ in ranked)
      ties = [pair for pair in ranked if pair[0] == lowest]
      gold = random.choice(ties)[1]
      features = network.extractFeatures(dicts, config)
      toApply = gold
      if debug :
        config.printForDebug(sys.stderr)
        print(str([[score, str(t)] for score, t in ranked])+"\n"+("-"*80)+"\n", file=sys.stderr)
      if dynamic :
        # Let the network choose what is applied; the label stays the oracle's.
        network.setState(config.state)
        output = network(features.unsqueeze(0).to(getDevice()))
        byScore = sorted([[float(output[0][k]), t.appliable(config), t] for k, t in enumerate(transitions)])[::-1]
        toApply = [t for _, ok, t in byScore if ok][0]
        if debug :
          print(str(toApply), file=sys.stderr)

      names = [str(t) for t in transitions]
      label = torch.LongTensor([names.index(str(gold))])
      perState[config.state].append(torch.cat([label, features]))

      if not applyTransition(strat, config, toApply, None) :
        break

    endOfSentence.apply(config, strat)

  return perState
################################################################################

################################################################################
def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted) :
  """Log the end-of-epoch results and save the model when it improved.

  Without `devFile`, the model is saved when `totalLoss` is lower than
  `bestLoss` (or on the first call). With `devFile`, the dev set is decoded
  into `modelDir`/predicted_dev.conllu, scored against the gold file with
  the CoNLL 2018 evaluation script, and the model is saved when the mean F1
  over the columns in `predicted` beats `bestScore`.
  The model goes to `modelDir`/network.pt; the log line is printed to
  stderr and to `modelDir`/train.log (truncated at epoch 1, appended after).

  Returns the updated (bestLoss, bestScore).
  """
  col2metric = {"HEAD" : "UAS", "DEPREL" : "LAS", "UPOS" : "UPOS", "FEATS" : "UFeats"}

  devScore = ""
  saved = True if bestLoss is None else totalLoss < bestLoss
  bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
  if devFile is not None :
    outFilename = modelDir+"/predicted_dev.conllu"
    # Closing the prediction file flushes it before it is read back below.
    with open(outFilename, "w") as outFile :
      Decode.decodeMode(debug, devFile, "model", ts, strat, rewardFunc, predicted, modelDir, model, dicts, outFile)
    with open(devFile, "r") as goldFile, open(outFilename, "r") as systemFile :
      res = evaluate(load_conllu(goldFile), load_conllu(systemFile), [])
    toEval = sorted(predicted)
    scores = [res[col2metric[col]][0].f1 for col in toEval]
    # With no predicted column there is no dev metric: keep the loss-based decision.
    if len(scores) > 0 :
      score = sum(scores)/len(scores)
      saved = True if bestScore is None else score > bestScore
      bestScore = score if bestScore is None else max(bestScore, score)
      devScore = ", Dev : "+" ".join(["%s=%.2f"%(col2metric[col], colScore) for col, colScore in zip(toEval, scores)])
  if saved :
    torch.save(model, modelDir+"/network.pt")
  logLine = "{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else "")
  print(logLine, file=sys.stderr)
  with open(modelDir+"/train.log", "w" if epoch == 1 else "a") as logFile :
    print(logLine, file=logFile)

  return bestLoss, bestScore
################################################################################

################################################################################
def _extractAllExamples(debug, transitionSets, strategy, sentencesOriginal, dicts, network, dynamic) :
  """Run extractExamples on fresh copies of all sentences.

  Returns (examples, totalNbExamples) where examples[s] stacks every example
  produced in parser state s. A state that produced none gets an empty
  LongTensor, because torch.stack rejects an empty list.
  Raises RuntimeError when no example at all could be extracted.
  """
  examples = [[] for _ in transitionSets]
  # extractExamples mutates the configs it parses, so work on copies.
  sentences = copy.deepcopy(sentencesOriginal)
  print("%s : Starting to extract %sexamples..."%(timeStamp(), "dynamic " if dynamic else ""), file=sys.stderr)
  for config in sentences :
    extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, dynamic)
    for state, stateExamples in enumerate(extracted) :
      examples[state] += stateExamples
  totalNbExamples = sum(map(len, examples))
  print("%s : Extracted %s examples"%(timeStamp(), prettyInt(totalNbExamples, 3)), file=sys.stderr)
  if totalNbExamples == 0 :
    raise RuntimeError("no training example could be extracted")
  examples = [torch.stack(stateExamples) if len(stateExamples) > 0 else torch.zeros(0, dtype=torch.long) for stateExamples in examples]
  return examples, totalNbExamples

def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
  """Train a network on transitions chosen by the oracle (supervised).

  Builds the dictionaries (FORM, UPOS, LETTER read from `filename`, plus a
  HISTORY dictionary of every transition name and the null token), saves
  them to `modelDir`/dicts.json, then trains for `nbEpochs` epochs with Adam
  and a cross-entropy loss. After each epoch, evalModelAndSave logs the
  results and saves the model when it improved.

  When `bootstrapInterval` is set, every `bootstrapInterval` epochs the
  examples are re-extracted while the current network chooses the applied
  transitions (labels still come from the oracle).
  """
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
  transitionNames = {}
  for ts in transitionSets :
    for t in ts :
      transitionNames[str(t)] = (len(transitionNames), 0)
  transitionNames[dicts.nullToken] = (len(transitionNames), 0)
  dicts.addDict("HISTORY", transitionNames)

  dicts.save(modelDir+"/dicts.json")
  network = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
  examples, totalNbExamples = _extractAllExamples(debug, transitionSets, strategy, sentencesOriginal, dicts, network, False)

  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
  optimizer = torch.optim.Adam(network.parameters(), lr=lr)
  lossFct = torch.nn.CrossEntropyLoss()
  bestLoss = None
  bestScore = None
  printInterval = 2000
  for epoch in range(1,nbEpochs+1) :
    if bootstrapInterval is not None and epoch > 1 and (epoch-1) % bootstrapInterval == 0 :
      # The network now picks the applied transitions: switch off training-only
      # behaviour (such as dropout, if the network has any) while it does so.
      network.eval()
      examples, totalNbExamples = _extractAllExamples(debug, transitionSets, strategy, sentencesOriginal, dicts, network, True)

    network.train()
    # Shuffle the examples of every state independently.
    examples = [stateExamples.index_select(0, torch.randperm(stateExamples.size(0))) for stateExamples in examples]
    totalLoss = 0.0
    nbEx = 0
    advancement = 0

    # Batches are drawn from states in proportion to their number of examples.
    distribution = [len(e)/totalNbExamples for e in examples]
    curIndexes = [0 for _ in examples]

    while True :
      state = random.choices(population=range(len(examples)), weights=distribution, k=1)[0]
      if curIndexes[state] >= len(examples[state]) :
        # The drawn state is exhausted: fall back to any state with examples left.
        state = -1
        for i in range(len(examples)) :
          if curIndexes[i] < len(examples[i]) :
            state = i
      if state == -1 :
        break

      batch = examples[state][curIndexes[state]:curIndexes[state]+batchSize].to(getDevice())
      curIndexes[state] += batchSize
      # Column 0 is the gold transition index, the rest are the features.
      targets = batch[:,:1].view(-1)
      inputs = batch[:,1:]
      nbEx += targets.size(0)
      advancement += targets.size(0)
      if not silent and advancement >= printInterval :
        advancement = 0
        print("Current epoch %6.2f%%"%(100.0*nbEx/totalNbExamples), end="\r", file=sys.stderr)

      network.setState(state)
      outputs = network(inputs)
      loss = lossFct(outputs, targets)
      network.zero_grad()
      loss.backward()
      optimizer.step()
      totalLoss += float(loss)

    bestLoss, bestScore = evalModelAndSave(debug, network, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc, predicted)
################################################################################

################################################################################
def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
  """Train a network by reinforcement learning with replay memories and a target network.

  Dictionaries are built as in oracle training and saved to
  `modelDir`/dicts.json. Two networks with the same architecture are
  created: `policy_net` is optimized, and `target_net` is refreshed from it
  and passed to Rl.optimizeModel together with the replay memories.
  Actions come from Rl.selectAction, with probabilities `probaRandom` and
  `probaOracle` that decay over epochs according to `probas`.

  An epoch ends, checked at sentence boundaries, once at least
  `nbExByEpoch` actions (the summed len() of the training sentences) have
  been taken. Sentences are consumed in order, and the corpus is re-copied
  and shuffled when exhausted. evalModelAndSave runs after every epoch.
  """

  memory = None
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
  transitionNames = {}
  for ts in transitionSets :
    for t in ts :
      transitionNames[str(t)] = (len(transitionNames), 0)
  transitionNames[dicts.nullToken] = (len(transitionNames), 0)
  dicts.addDict("HISTORY", transitionNames)
  dicts.save(modelDir + "/dicts.json")

  policy_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
  target_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
  target_net.load_state_dict(policy_net.state_dict())
  target_net.eval()
  policy_net.train()
  optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)
  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)

  bestLoss = None
  bestScore = None

  sentences = copy.deepcopy(sentencesOriginal)
  nbExByEpoch = sum(map(len,sentences))
  sentIndex = 0

  for epoch in range(1,nbIter+1) :
    # probas[k] = (value at epoch 1, decay time constant in epochs, asymptotic value);
    # k=0 gives probaRandom, k=1 gives probaOracle.
    probaRandom = round((probas[0][0]-probas[0][2])*math.exp((-epoch+1)/probas[0][1])+probas[0][2], 2)
    probaOracle = round((probas[1][0]-probas[1][2])*math.exp((-epoch+1)/probas[1][1])+probas[1][2], 2)
    i = 0
    totalLoss = 0.0
    while True :
      if sentIndex >= len(sentences) :
        # Corpus exhausted: start over from pristine, shuffled copies.
        sentences = copy.deepcopy(sentencesOriginal)
        random.shuffle(sentences)
        sentIndex = 0

      if not silent :
        print("Current epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr)
      sentence = sentences[sentIndex]
      sentence.moveWordIndex(0)
      state = policy_net.extractFeatures(dicts, sentence).to(getDevice())

      # Number of non-appliable actions selected in this sentence.
      count = 0

      while True :
        missingLinks = getMissingLinks(sentence)
        transitionSet = transitionSets[sentence.state]
        fromState = sentence.state
        toState = sentence.state

        if debug :
          sentence.printForDebug(sys.stderr)
        action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle, fromState)

        if action is None :
          break

        if debug :
          print("Selected action : %s"%str(action), file=sys.stderr)

        appliable = action.appliable(sentence)

        reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc)
        reward = torch.FloatTensor([reward_]).to(getDevice())

        newState = None
        toState = strategy[action.name][1] if action.name in strategy else -1
        if appliable :
          applyTransition(strategy, sentence, action, reward_)
          newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
        else:
          count+=1

        if memory is None :
          # One replay memory per (fromState, toState) pair, created lazily because the
          # feature size (state.numel()) is only known once features have been extracted.
          memory = [[ReplayMemory(5000, state.numel(), f, t) for t in range(len(transitionSets))] for f in range(len(transitionSets))]
        # NOTE(review): toState is -1 for actions absent from `strategy`, which indexes the
        # LAST memory of the row rather than a dedicated slot — confirm this is intended.
        memory[fromState][toState].push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
        state = newState
        if i % batchSize == 0 :
          totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)
          # NOTE(review): this condition is implied by the enclosing one, so the target network
          # is synced after every optimization step; `1*` looks like a placeholder multiplier.
          if i % (1*batchSize) == 0 :
            target_net.load_state_dict(policy_net.state_dict())
            target_net.eval()
            policy_net.train()
        i += 1

        # A non-appliable action leaves newState (hence state) at None, so the sentence also
        # ends on the first non-appliable action, whatever countBreak is.
        if state is None or count == countBreak:
          break
      # Advance before the epoch check so the next epoch starts on an unprocessed sentence
      # instead of re-parsing the config that was just mutated.
      sentIndex += 1
      if i >= nbExByEpoch :
        break
    bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted)
################################################################################