import sys
import random
import torch

from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp
import Networks
import Decode
import Config

from conll18_ud_eval import load_conllu, evaluate

################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
  """Entry point of training: validate the training type, read the corpus and train.

  Parameters
  ----------
  debug : forwarded to trainModelOracle.
  filename : path of the training corpus, read with Config.readConllu.
  type : training type; only "oracle" is supported.
  modelDir, nbIter, batchSize, devFile, silent : forwarded unchanged to
    trainModelOracle.

  Exits the process with status 1 (after printing an error on stderr) when
  `type` is not a supported training type.
  """
  # Fail fast: reject an unsupported type before paying for reading the corpus.
  if type != "oracle" :
    print("ERROR : unknown type '%s'"%type, file=sys.stderr)
    # sys.exit rather than the bare `exit`, which is only injected by the
    # `site` module and is not guaranteed to be available.
    sys.exit(1)

  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

  sentences = Config.readConllu(filename)

  trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
################################################################################

################################################################################
def extractExamples(ts, strat, config, dicts, debug=False) :
  """Replay the oracle on one configuration and collect one example per step.

  At every step the transitions of `ts` that are appliable to `config` are
  ranked by (oracle score, name) in ascending order and the first one is
  chosen. The example stored for that step is a tensor made of the index of
  the chosen transition in `ts` followed by the features returned by
  Features.extractFeatures (assumed to be a 1-D integer tensor so that it can
  be concatenated with the index -- TODO confirm dtype). The chosen transition
  is then applied through applyTransition; the loop stops when no transition
  is appliable or when applyTransition returns a falsy value. An "EOS"
  transition is applied to `config` at the very end.

  Returns the list of example tensors, in the order they were produced.
  """
  collected = []

  endOfSentence = Transition("EOS")
  config.moveWordIndex(0)
  while True :
    links = getMissingLinks(config)
    ranked = [[t.getOracleScore(config, links), t.name] for t in ts if t.appliable(config)]
    ranked.sort()
    if not ranked :
      break
    chosenName = ranked[0][1]
    # Position of the chosen transition in `ts`: this is the class label.
    chosenIndex = next(pos for pos, t in enumerate(ts) if t.name == chosenName)
    label = torch.LongTensor([chosenIndex])
    collected.append(torch.cat([label, Features.extractFeatures(dicts, config)]))
    if debug :
      config.printForDebug(sys.stderr)
      separator = "-"*80
      print(str(ranked)+"\n"+separator+"\n", file=sys.stderr)
    if not applyTransition(ts, strat, config, chosenName) :
      break

  endOfSentence.apply(config)

  return collected
################################################################################

################################################################################
def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
  """Train a Networks.BaseNet classifier on oracle examples and save the best model.

  Steps:
    1. Build the dictionaries from the "FORM" and "UPOS" columns of `filename`
       and save them to modelDir/dicts.json.
    2. Extract the training examples of every configuration in `sentences`
       with extractExamples and stack them into a single tensor whose first
       column is the target class and whose other columns are the inputs.
    3. Run `nbIter` epochs of Adam / cross-entropy training over shuffled
       mini-batches of at most `batchSize` examples.
    4. After each epoch, if `devFile` is given, decode it to
       modelDir/predicted_dev.conllu and compute its UAS.

  The network is saved to modelDir/network.pt whenever the epoch improves on
  the best dev UAS seen so far (when `devFile` is given) or on the best total
  training loss (otherwise). A summary line is printed on stderr per epoch;
  intra-epoch progress is printed unless `silent` is true.
  """
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM", "UPOS"])
  dicts.save(modelDir+"/dicts.json")

  print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
  examples = []
  for config in sentences :
    examples += extractExamples(transitionSet, strategy, config, dicts, debug)
  print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)
  examples = torch.stack(examples)
  nbExamples = examples.size(0)

  # Column 0 holds the target, hence the -1 on the input size.
  network = Networks.BaseNet(dicts, examples.size(1)-1, len(transitionSet))
  optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
  lossFct = torch.nn.CrossEntropyLoss()
  printInterval = 2000
  bestLoss = None
  bestScore = None
  for epoch in range(1,nbIter+1) :
    # Re-assert training mode at every epoch: the dev decoding at the end of
    # the previous epoch may have switched the network to eval mode
    # (NOTE(review): Decode.decodeMode is not visible here -- confirm).
    network.train()
    examples = examples.index_select(0, torch.randperm(nbExamples))
    totalLoss = 0.0
    nbEx = 0
    advancement = 0
    # Step up to nbExamples (not nbExamples-batchSize) so that the last batch,
    # possibly shorter than batchSize, is trained on too; slicing past the end
    # of the tensor simply yields the remaining rows.
    for batchIndex in range(0,nbExamples,batchSize) :
      batch = examples[batchIndex:batchIndex+batchSize]
      targets = batch[:,:1].view(-1)
      inputs = batch[:,1:]
      nbEx += targets.size(0)
      advancement += targets.size(0)
      if not silent and advancement >= printInterval :
        advancement = 0
        print("Curent epoch %6.2f%%"%(100.0*nbEx/nbExamples), end="\r", file=sys.stderr)
      outputs = network(inputs)
      loss = lossFct(outputs, targets)
      network.zero_grad()
      loss.backward()
      optimizer.step()
      totalLoss += float(loss)

    devScore = ""
    # Default criterion: total training loss of the epoch.
    saved = True if bestLoss is None else totalLoss < bestLoss
    bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
    if devFile is not None :
      outFilename = modelDir+"/predicted_dev.conllu"
      # Close (and therefore flush) the prediction file before it is read back
      # by the evaluation below.
      with open(outFilename, "w") as outFile :
        Decode.decodeMode(debug, devFile, "model", network, dicts, outFile)
      with open(devFile, "r") as goldFile, open(outFilename, "r") as predFile :
        res = evaluate(load_conllu(goldFile), load_conllu(predFile), [])
      UAS = res["UAS"][0].f1
      score = UAS
      # With a dev file, the dev UAS overrides the loss-based criterion.
      saved = True if bestScore is None else score > bestScore
      bestScore = score if bestScore is None else max(bestScore, score)
      devScore = ", Dev : UAS=%.2f"%(UAS)
    if saved :
      torch.save(network, modelDir+"/network.pt")
    print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), epoch, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
################################################################################