#! /usr/bin/env python3
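"""Train a transition-based dependency parser or decode CoNLL-U files with it.

Usage sketch (inferred from the argument parser below; file paths are placeholders):
    ./main.py train oracle train.conllu modelDir --dev dev.conllu --iter 5
    ./main.py decode oracle test.conllu modelDir > predicted.conllu
"""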

import sys
import os
import argparse
from datetime import datetime
import Config
import Decode
import Train
from Transition import Transition
import Networks
from Dicts import Dicts

from conll18_ud_eval import load_conllu, evaluate

import torch

################################################################################
def timeStamp() :
  return "[%s]"%datetime.now().strftime("%H:%M:%S")
################################################################################

################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
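  """Train a parser on 'filename' and save its artifacts into 'modelDir'.

  Only type == "oracle" is implemented: oracle examples are extracted from the
  training corpus and a Networks.BaseNet classifier is trained to predict the
  next transition.
  """
  # Arc-eager transition set; 'strategy' is consumed by Train and Decode
  # (presumably the number of positions the word index advances per transition).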
  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

  sentences = Config.readConllu(filename)

  if type == "oracle" :
    examples = []
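    # Build string-to-index dictionaries for FORM and UPOS and save them
    # alongside the model so that a later decoding run can reuse the same mapping.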
    dicts = Dicts()
    dicts.readConllu(filename, ["FORM", "UPOS"])
    dicts.save(modelDir+"/dicts.json")
    print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
    for config in sentences :
      examples += Train.extractExamples(transitionSet, strategy, config, dicts, debug)
    print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)
    examples = torch.stack(examples)

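    # Each example row is [gold transition index, feature values...], which is
    # why the network's input size is examples[0].size(0)-1.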
    network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
    network.train()
    optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
    lossFct = torch.nn.CrossEntropyLoss()
    for epoch in range(1,nbIter+1) :
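      # Shuffle the examples at the start of every epoch.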
      examples = examples.index_select(0, torch.randperm(examples.size(0)))
      totalLoss = 0.0
      nbEx = 0
      printInterval = 2000
      advancement = 0
      for batchIndex in range(0,examples.size(0)-batchSize,batchSize) :
        batch = examples[batchIndex:batchIndex+batchSize]
        targets = batch[:,:1].view(-1)
        inputs = batch[:,1:]
        nbEx += targets.size(0)
        advancement += targets.size(0)
        if not silent and advancement >= printInterval :
          advancement = 0
          print("Curent epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr)
        outputs = network(inputs)
        loss = lossFct(outputs, targets)
        network.zero_grad()
        loss.backward()
        optimizer.step()
        totalLoss += float(loss)
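      # After each epoch, optionally decode the dev corpus with the current model
      # and report UAS computed by the CoNLL 2018 shared task evaluation script.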
      devScore = ""
      if devFile is not None :
        outFilename = modelDir+"/predicted_dev.conllu"
        decodeMode(debug, devFile, "model", network, dicts, open(outFilename, "w"))
        res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
        devScore = ", Dev : UAS=%.2f"%(res["UAS"][0].f1)
      print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), iter, totalLoss, devScore), file=sys.stderr)
    return

  print("ERROR : unknown type '%s'"%type, file=sys.stderr)
  exit(1)
################################################################################

################################################################################
def decodeMode(debug, filename, type, network=None, dicts=None, output=sys.stdout) :
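  """Decode 'filename' and print the predicted CoNLL-U to 'output'.

  type can be "random", "oracle", or "model"; the "model" type expects a
  trained network and its dicts to be passed in.
  """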
  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

  sentences = Config.readConllu(filename)

  if type in ["random", "oracle"] :
    decodeFunc = Decode.oracleDecode if type == "oracle" else Decode.randomDecode
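    # "oracle" presumably decodes using gold annotations, "random" picks transitions at random.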
    for config in sentences :
      decodeFunc(transitionSet, strategy, config, debug)
    sentences[0].print(output, header=True)
    for config in sentences[1:] :
      config.print(output, header=False)
  elif type == "model" :
    for config in sentences :
      Decode.decodeModel(transitionSet, strategy, config, network, dicts, debug)
    sentences[0].print(output, header=True)
    for config in sentences[1:] :
      config.print(output, header=False)
  else :
    print("ERROR : unknown type '%s'"%type, file=sys.stderr)
    exit(1)
################################################################################

################################################################################
if __name__ == "__main__" :
  parser = argparse.ArgumentParser()
  parser.add_argument("mode", type=str,
    help="What to do : train | decode")
  parser.add_argument("type", type=str,
    help="Type of train or decode. random | oracle")
  parser.add_argument("corpus", type=str,
    help="Name of the CoNLL-U file. Train file for train mode and input file for decode mode.")
  parser.add_argument("model", type=str,
    help="Path to the model directory.")
  parser.add_argument("--iter", "-n", default=5,
    help="Number of training epoch.")
  parser.add_argument("--batchSize", default=64,
    help="Size of each batch.")
  parser.add_argument("--dev", default=None,
    help="Name of the CoNLL-U file of the dev corpus.")
  parser.add_argument("--debug", "-d", default=False, action="store_true",
    help="Print debug infos on stderr.")
  parser.add_argument("--silent", "-s", default=False, action="store_true",
    help="Don't print advancement infos.")

  args = parser.parse_args()

  os.makedirs(args.model, exist_ok=True)

  if args.mode == "train" :
    trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent)
  elif args.mode == "decode" :
    decodeMode(args.debug, args.corpus, args.type)
  else :
    print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
    exit(1)
################################################################################