#! /usr/bin/env python3

import sys
import os
import argparse
from datetime import datetime

import Config
import Decode
import Train
from Transition import Transition
import Networks
from Dicts import Dicts

from conll18_ud_eval import load_conllu, evaluate

import torch

################################################################################
def timeStamp() :
  """Return the current local wall-clock time as a "[HH:MM:SS]" log prefix."""
  now = datetime.now()
  return "[" + now.strftime("%H:%M:%S") + "]"
################################################################################

################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
  """Train a parser network on a CoNLL-U corpus and report the loss of each epoch.

  Only the "oracle" training type is supported: training examples are extracted
  from the corpus with Train.extractExamples, then a Networks.BaseNet is trained
  on them with Adam (lr=0.0001) and cross-entropy loss.

  Params:
    debug : flag forwarded to Train.extractExamples and to dev decoding.
    filename : path of the training CoNLL-U file.
    type : training type; anything other than "oracle" is an error.
    modelDir : directory receiving dicts.json and, with a dev file, predicted_dev.conllu.
    nbIter : number of training epochs.
    batchSize : number of examples per optimisation step (must be >= 1).
    devFile : optional CoNLL-U dev file; when not None, dev UAS is reported each epoch.
    silent : when True, the in-epoch progress line is not printed.

  Exits the process with status 1 on an unknown type, a non-positive batch size,
  or when no training example could be extracted.
  """
  if type != "oracle" :
    print("ERROR : unknown type '%s'"%type, file=sys.stderr)
    sys.exit(1)
  if batchSize < 1 :
    print("ERROR : batch size must be >= 1, got %d"%batchSize, file=sys.stderr)
    sys.exit(1)

  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

  sentences = Config.readConllu(filename)

  dicts = Dicts()
  dicts.readConllu(filename, ["FORM", "UPOS"])
  dicts.save(modelDir+"/dicts.json")

  print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
  examples = []
  for config in sentences :
    examples += Train.extractExamples(transitionSet, strategy, config, dicts, debug)
  print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)
  if not examples :
    # torch.stack would raise on an empty list; fail with a readable message instead.
    print("ERROR : no training example extracted from '%s'"%filename, file=sys.stderr)
    sys.exit(1)
  examples = torch.stack(examples)

  # Each example row is [target transition, input features...]: column 0 is the
  # target, the remaining columns are the network input (see the slicing below).
  network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
  optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
  lossFct = torch.nn.CrossEntropyLoss()
  printInterval = 2000
  nbExamples = examples.size(0)

  for epoch in range(1, nbIter+1) :
    # Re-enter training mode every epoch, in case the dev decoding of the previous
    # epoch switched the network mode. NOTE(review): depends on Decode.decodeModel.
    network.train()
    examples = examples.index_select(0, torch.randperm(nbExamples))
    totalLoss = 0.0
    nbEx = 0
    advancement = 0
    # Step over every example, including a final partial batch, so that no data
    # is dropped (and tiny corpora are still trained on).
    for batchIndex in range(0, nbExamples, batchSize) :
      batch = examples[batchIndex:batchIndex+batchSize]
      targets = batch[:,:1].view(-1)
      inputs = batch[:,1:]
      nbEx += targets.size(0)
      advancement += targets.size(0)
      if not silent and advancement >= printInterval :
        advancement = 0
        print("Current epoch %6.2f%%"%(100.0*nbEx/nbExamples), end="\r", file=sys.stderr)
      outputs = network(inputs)
      loss = lossFct(outputs, targets)
      network.zero_grad()
      loss.backward()
      optimizer.step()
      totalLoss += loss.item()

    devScore = ""
    if devFile is not None :
      outFilename = modelDir+"/predicted_dev.conllu"
      # The prediction file must be closed (hence flushed) before it is read back.
      with open(outFilename, "w") as devOutput :
        decodeMode(debug, devFile, "model", network, dicts, devOutput)
      with open(devFile, "r") as goldFile, open(outFilename, "r") as predFile :
        res = evaluate(load_conllu(goldFile), load_conllu(predFile), [])
      devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1)
    print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), epoch, totalLoss, devScore), file=sys.stderr)
################################################################################

################################################################################
def decodeMode(debug, filename, type, network=None, dicts=None, output=None) :
  """Decode every sentence of a CoNLL-U file and write the result as CoNLL-U.

  Params:
    debug : flag forwarded to the Decode functions.
    filename : path of the CoNLL-U file to decode.
    type : "random" (Decode.randomDecode), "oracle" (Decode.oracleDecode)
      or "model" (Decode.decodeModel with the given network and dicts).
    network : trained network; required when type is "model".
    dicts : dictionaries passed to Decode.decodeModel when type is "model".
    output : writable text stream receiving the decoded corpus. Defaults to
      sys.stdout, looked up at call time so a redirected sys.stdout is honoured.

  Exits the process with status 1 on an unknown type, or on a "model" decode
  without a network.
  """
  if output is None :
    output = sys.stdout

  if type not in ["random", "oracle", "model"] :
    print("ERROR : unknown type '%s'"%type, file=sys.stderr)
    sys.exit(1)
  if type == "model" and network is None :
    print("ERROR : type 'model' requires a trained network", file=sys.stderr)
    sys.exit(1)

  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

  sentences = Config.readConllu(filename)

  if type == "model" :
    for config in sentences :
      Decode.decodeModel(transitionSet, strategy, config, network, dicts, debug)
  else :
    decodeFunc = Decode.oracleDecode if type == "oracle" else Decode.randomDecode
    for config in sentences :
      decodeFunc(transitionSet, strategy, config, debug)

  # The header is written only before the first sentence; an empty corpus
  # produces no output instead of an IndexError.
  for index, config in enumerate(sentences) :
    config.print(output, header=(index == 0))
################################################################################

################################################################################
if __name__ == "__main__" :
  # Command-line entry point. `args` is kept at module level on purpose: it is
  # visible as a global to the functions above.
  parser = argparse.ArgumentParser()
  parser.add_argument("mode", type=str,
    help="What to do : train | decode")
  parser.add_argument("type", type=str,
    help="Type of train or decode. random | oracle")
  parser.add_argument("corpus", type=str,
    help="Name of the CoNLL-U file. Train file for train mode and input file for decode mode.")
  parser.add_argument("model", type=str,
    help="Path to the model directory.")
  parser.add_argument("--iter", "-n", type=int, default=5,
    help="Number of training epoch.")
  parser.add_argument("--batchSize", type=int, default=64,
    help="Size of each batch.")
  parser.add_argument("--dev", default=None,
    help="Name of the CoNLL-U file of the dev corpus.")
  parser.add_argument("--debug", "-d", default=False, action="store_true",
    help="Print debug infos on stderr.")
  parser.add_argument("--silent", "-s", default=False, action="store_true",
    help="Don't print advancement infos.")
  args = parser.parse_args()

  # Reject values that would make training meaningless or crash (step 0 in range()).
  if args.iter < 0 :
    parser.error("--iter must be >= 0")
  if args.batchSize < 1 :
    parser.error("--batchSize must be >= 1")

  os.makedirs(args.model, exist_ok=True)

  if args.mode == "train" :
    trainMode(args.debug, args.corpus, args.type, args.model, args.iter, args.batchSize, args.dev, args.silent)
  elif args.mode == "decode" :
    decodeMode(args.debug, args.corpus, args.type)
  else :
    print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
    sys.exit(1)
################################################################################