import sys
import random
import torch
import copy

from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp, prettyInt, numParameters, getDevice
from Rl import ReplayMemory, selectAction, optimizeModel
import Networks
import Decode
import Config

from conll18_ud_eval import load_conllu, evaluate

################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
  """Train a parser model on a CoNLL-U file with the requested method.

  Parameters:
    debug     -- forwarded to the trainer (enables debug printing).
    filename  -- CoNLL-U training file.
    type      -- training method: "oracle" (supervised on oracle transitions)
                 or "rl" (reinforcement learning).
    modelDir  -- output directory for the trainer's files.
    nbIter    -- number of epochs.
    batchSize -- batch size forwarded to the trainer.
    devFile   -- optional CoNLL-U dev file used for model selection (None to disable).
    silent    -- if true, suppress progress output.

  Exits the process with status 1 if `type` is unknown.
  """
  # Validate the method before parsing the corpus so that a typo fails fast.
  if type not in ("oracle", "rl") :
    print("ERROR : unknown type '%s'"%type, file=sys.stderr)
    # sys.exit rather than the site-provided exit(), which does not exist
    # under `python -S` or in frozen executables.
    sys.exit(1)

  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
  # Forwarded to applyTransition. NOTE(review): the meaning of the 1/0 values is defined there.
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

  sentences = Config.readConllu(filename)

  if type == "oracle" :
    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
  else :
    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
################################################################################

################################################################################
def extractExamples(ts, strat, config, dicts, debug=False) :
  """Follow the oracle over one sentence and return its training examples.

  Starting from word index 0, every appliable transition of `ts` is scored with
  its oracle score. The lowest score is kept, with ties broken by transition name.
  One example is recorded and that transition is applied. The loop stops when no
  transition is appliable or when applyTransition returns a falsy value. The
  "EOS" transition is then applied.

  Parameters:
    ts     -- list of Transition objects; an example's label is the index of the
              chosen transition in this list.
    strat  -- strategy forwarded to applyTransition.
    config -- sentence configuration; MUTATED in place (word index, applied
              transitions, final EOS).
    dicts  -- Dicts forwarded to Features.extractFeatures.
    debug  -- if true, print each configuration and its scored candidates to stderr.

  Returns:
    list of 1-D tensors, each the gold transition index followed by the features.
    NOTE(review): assumes Features.extractFeatures returns a 1-D LongTensor so
    the concatenation stays integer-typed -- confirm.
  """
  examples = []
  # Loop-invariant: a transition's label is its position in `ts`.
  transNames = [trans.name for trans in ts]

  EOS = Transition("EOS")
  config.moveWordIndex(0)
  moved = True
  while moved :
    missingLinks = getMissingLinks(config)
    # [score, name] pairs; sorting puts the lowest oracle score first.
    candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config)])
    if not candidates :
      break
    candidate = candidates[0][1]
    candidateIndex = transNames.index(candidate)
    features = Features.extractFeatures(dicts, config)
    examples.append(torch.cat([torch.LongTensor([candidateIndex]), features]))
    if debug :
      config.printForDebug(sys.stderr)
      print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
    moved = applyTransition(ts, strat, config, candidate)

  EOS.apply(config)

  return examples
################################################################################

################################################################################
def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter) :
  """Log one epoch's result and save the model if it is the best so far.

  Without a dev file, the model is saved when `totalLoss` is lower than every
  previous epoch's loss. With a dev file, three steps decide instead:
    1. `model` decodes it into modelDir/predicted_dev.conllu.
    2. The prediction is scored against the gold file with the CoNLL 2018 evaluator.
    3. The model is saved when its UAS F1 beats the best score so far.
  The loss is still tracked in that case, but it no longer decides saving.

  Parameters:
    debug     -- forwarded to Decode.decodeMode.
    model     -- network to evaluate; saved whole with torch.save to modelDir/network.pt.
    dicts     -- dictionaries forwarded to Decode.decodeMode.
    modelDir  -- output directory.
    devFile   -- gold CoNLL-U dev file, or None to select on training loss.
    bestLoss  -- best training loss so far, or None on the first epoch.
    totalLoss -- training loss of the current epoch.
    bestScore -- best dev UAS so far, or None.
    epoch     -- current epoch number (for the log line).
    nbIter    -- total number of epochs (for the log line).

  Returns:
    (bestLoss, bestScore) updated with this epoch's values.
  """
  devScore = ""
  saved = True if bestLoss is None else totalLoss < bestLoss
  bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
  if devFile is not None :
    outFilename = modelDir+"/predicted_dev.conllu"
    # Close (and therefore flush) the prediction file before reading it back.
    # An anonymous handle is only flushed when it happens to be garbage-collected.
    with open(outFilename, "w", encoding="utf-8") as outFile :
      Decode.decodeMode(debug, devFile, "model", modelDir, model, dicts, outFile)
    # CoNLL-U is UTF-8 by specification; do not depend on the locale default.
    with open(devFile, "r", encoding="utf-8") as goldFile, open(outFilename, "r", encoding="utf-8") as predFile :
      res = evaluate(load_conllu(goldFile), load_conllu(predFile), [])
    UAS = res["UAS"][0].f1
    score = UAS
    # With a dev set, model selection is driven by UAS rather than by the loss.
    saved = True if bestScore is None else score > bestScore
    bestScore = score if bestScore is None else max(bestScore, score)
    devScore = ", Dev : UAS=%.2f"%(UAS)
  if saved :
    torch.save(model, modelDir+"/network.pt")
  print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)

  return bestLoss, bestScore
################################################################################

################################################################################
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
  """Train a transition classifier by supervised learning on oracle transitions.

  The procedure has four steps:
    1. Build FORM/UPOS dictionaries from `filename` and save them to
       modelDir/dicts.json.
    2. Extract one (gold transition, features) example per oracle step of
       every sentence.
    3. Train Networks.BaseNet for `nbEpochs` epochs with Adam (lr=1e-4) and
       cross-entropy.
    4. After each epoch, evaluate and possibly save the model via evalModelAndSave.

  Parameters:
    debug         -- forwarded to extractExamples and evalModelAndSave.
    modelDir      -- output directory.
    filename      -- CoNLL-U training file used to build the dictionaries.
    nbEpochs      -- number of training epochs.
    batchSize     -- mini-batch size.
    devFile       -- optional CoNLL-U dev file for model selection (None to disable).
    transitionSet -- list of Transition objects; its length is the number of output classes.
    strategy      -- forwarded to applyTransition through extractExamples.
    sentences     -- training configurations; mutated by extractExamples.
    silent        -- if true, suppress the progress indicator.
  """
  examples = []
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM", "UPOS"])
  dicts.save(modelDir+"/dicts.json")
  print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
  for config in sentences :
    examples += extractExamples(transitionSet, strategy, config, dicts, debug)
  print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
  # One row per example: column 0 is the gold transition index, the rest are features.
  examples = torch.stack(examples)

  network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet)).to(getDevice())
  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
  optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
  lossFct = torch.nn.CrossEntropyLoss()
  bestLoss = None
  bestScore = None
  for epoch in range(1,nbEpochs+1) :
    network.train()
    examples = examples.index_select(0, torch.randperm(examples.size(0)))
    totalLoss = 0.0
    nbEx = 0
    printInterval = 2000
    advancement = 0
    # Visit every full batch. The former bound (size - batchSize) also dropped
    # the last *full* batch, and trained on nothing when size == batchSize.
    # Only the incomplete tail (< batchSize examples) is still skipped. The
    # shuffle makes that tail differ each epoch, and every step sees exactly
    # batchSize examples.
    for batchIndex in range(0, examples.size(0)-batchSize+1, batchSize) :
      batch = examples[batchIndex:batchIndex+batchSize].to(getDevice())
      targets = batch[:,:1].view(-1)
      inputs = batch[:,1:]
      nbEx += targets.size(0)
      advancement += targets.size(0)
      if not silent and advancement >= printInterval :
        advancement = 0
        print("Curent epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr)
      outputs = network(inputs)
      loss = lossFct(outputs, targets)
      network.zero_grad()
      loss.backward()
      optimizer.step()
      totalLoss += float(loss)

    bestLoss, bestScore = evalModelAndSave(debug, network, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs)
################################################################################

################################################################################
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, silent=False) :
  """Train a transition classifier by reinforcement learning.

  Uses a replay memory and a policy/target network pair (DQN-style; the
  actual update rule lives in Rl.optimizeModel). Each epoch works through a
  fresh deep copy of `sentencesOriginal`:
    - every sentence is parsed with actions chosen by Rl.selectAction;
    - each (state, action index, next state, reward) is pushed into the memory;
    - the policy network is optimized every `batchSize` steps, and the target
      network is re-synchronized every `2*batchSize` steps.
  After each epoch the policy network is passed to evalModelAndSave.

  Parameters:
    debug             -- if true, print each configuration to stderr before choosing an action.
    modelDir          -- output directory (dicts.json is written here).
    filename          -- CoNLL-U training file used to build the FORM/UPOS dictionaries.
    nbIter            -- number of epochs.
    batchSize         -- forwarded to optimizeModel; also the step interval between optimizations.
    devFile           -- optional CoNLL-U dev file forwarded to evalModelAndSave (None to skip).
    transitionSet     -- list of Transition objects; the index of the chosen action in this
                         list is what gets stored in the memory.
    strategy          -- forwarded to applyTransition.
    sentencesOriginal -- training configurations; left untouched (deep-copied every epoch).
    silent            -- if true, suppress the progress indicator.
  """

  memory = None
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM", "UPOS"])
  dicts.save(modelDir + "/dicts.json")

  # Networks and optimizer are built lazily, once the size of the first state is known.
  policy_net = None
  target_net = None
  optimizer = None

  bestLoss = None
  bestScore = None

  for epoch in range(1,nbIter+1) :
    # Step counter driving the optimization / target-sync schedule; reset every epoch.
    i = 0
    totalLoss = 0.0
    # Applying transitions mutates the configurations, so each epoch gets a fresh copy.
    sentences = copy.deepcopy(sentencesOriginal)
    for sentIndex in range(len(sentences)) :
      if not silent :
        print("Curent epoch %6.2f%%"%(100.0*sentIndex/len(sentences)), end="\r", file=sys.stderr)
      sentence = sentences[sentIndex]
      sentence.moveWordIndex(0)
      state = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())

      if policy_net is None :
        # Input size is the number of elements of the first extracted state.
        policy_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet)).to(getDevice())
        target_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet)).to(getDevice())
        # The target network starts as an exact copy of the policy network and is never trained directly.
        target_net.load_state_dict(policy_net.state_dict())
        target_net.eval()
        policy_net.train()
        optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
        print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)

      while True :
        missingLinks = getMissingLinks(sentence)
        if debug :
          sentence.printForDebug(sys.stderr)
        # NOTE(review): probaRandom / probaOracle presumably are the probabilities of taking a
        # random or an oracle action instead of the policy's choice -- confirm in Rl.selectAction.
        action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom=0.3, probaOracle=0.15)
        # No action available: the sentence is finished.
        if action is None :
          break

        # Reward is the negated oracle score, so transitions the oracle scores lower earn more.
        reward = -1.0*action.getOracleScore(sentence, missingLinks)
        reward = torch.FloatTensor([reward]).to(getDevice())

        # The return value of applyTransition is not used here; the loop ends only through selectAction.
        applyTransition(transitionSet, strategy, sentence, action.name)
        newState = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())

        # NOTE(review): 1000 is presumably the replay memory capacity -- confirm in Rl.ReplayMemory.
        if memory is None :
          memory = ReplayMemory(1000, state.numel())
        memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
        state = newState
        # NOTE(review): the first optimization runs at i == 0 with a single stored transition;
        # presumably optimizeModel copes with a memory smaller than batchSize -- confirm.
        if i % batchSize == 0 :
          totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
          # Re-synchronize the target network every 2*batchSize steps.
          if i % (2*batchSize) == 0 :
            target_net.load_state_dict(policy_net.state_dict())
            target_net.eval()
            policy_net.train()
        i += 1
    # NOTE(review): if sentencesOriginal is empty, policy_net is still None at this point.
    bestLoss, bestScore = evalModelAndSave(debug, policy_net, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter)
################################################################################