Skip to content
Snippets Groups Projects
Train.py 4 KiB
Newer Older
import sys
import random

import torch

from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp
import Networks
import Decode
import Config
from conll18_ud_eval import load_conllu, evaluate

################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
  """Entry point for training a parser model from a CoNLL-U corpus.

  Parameters:
    debug     : if true, oracle decisions are traced on stderr while extracting examples.
    filename  : path of the CoNLL-U training corpus.
    type      : training method; only "oracle" is supported.
    modelDir  : directory where dicts.json (and the dev predictions) are written.
    nbIter    : number of training epochs.
    batchSize : number of examples per optimizer step.
    devFile   : CoNLL-U dev corpus evaluated after each epoch, or None to skip it.
    silent    : if true, the in-epoch progress indicator is not printed.

  Prints an error and exits the process with status 1 when `type` is unknown.
  """
  # Reject unknown training types before spending time parsing the corpus.
  if type != "oracle" :
    print("ERROR : unknown type '%s'"%type, file=sys.stderr)
    sys.exit(1)

  # The position of each transition in this list is the class label the
  # network learns to predict (see how examples are built from its index).
  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
  # NOTE(review): forwarded to applyTransition; the meaning of the 0/1 values
  # is defined in the Transition module — confirm there.
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

  sentences = Config.readConllu(filename)
  trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
################################################################################

################################################################################
def extractExamples(ts, strat, config, dicts, debug=False) :
  """Replay the oracle over one sentence and return the resulting examples.

  At each step, every appliable transition of `ts` is scored with its oracle
  score. The lowest [score, name] pair wins. A 1-D tensor is recorded whose
  first element is the winner's index in `ts`, followed by the features of the
  configuration as it was *before* the transition is applied.

  The walk stops when no transition is appliable or when applyTransition
  returns a falsy value. The EOS transition is then applied to `config`.
  """
  eos = Transition("EOS")
  transNames = [trans.name for trans in ts]
  collected = []

  config.moveWordIndex(0)
  keepGoing = True
  while keepGoing :
    missingLinks = getMissingLinks(config)
    scored = [[trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config)]
    if not scored :
      break
    scored.sort()
    bestName = scored[0][1]
    label = torch.LongTensor([transNames.index(bestName)])
    collected.append(torch.cat([label, Features.extractFeatures(dicts, config)]))
    if debug :
      config.printForDebug(sys.stderr)
      print(str(scored)+"\n"+("-"*80)+"\n", file=sys.stderr)
    keepGoing = applyTransition(ts, strat, config, bestName)

  eos.apply(config)

  return collected
################################################################################

################################################################################
def _trainEpoch(network, optimizer, lossFct, examples, batchSize, silent, printInterval=2000) :
  """Run one optimization pass over `examples` and return the summed batch loss.

  Column 0 of `examples` holds the target transition index. The remaining
  columns are the network inputs. The final batch may be smaller than
  `batchSize`, so every example is used once per epoch.
  """
  nbTotal = examples.size(0)
  totalLoss = 0.0
  nbEx = 0
  advancement = 0
  # Step over the whole tensor: the previous stop value (N - batchSize)
  # silently dropped the last batch, and trained on nothing when N <= batchSize.
  for batchIndex in range(0, nbTotal, batchSize) :
    batch = examples[batchIndex:batchIndex+batchSize]
    targets = batch[:,:1].view(-1)
    inputs = batch[:,1:]
    nbEx += targets.size(0)
    advancement += targets.size(0)
    if not silent and advancement >= printInterval :
      advancement = 0
      print("Current epoch %6.2f%%"%(100.0*nbEx/nbTotal), end="\r", file=sys.stderr)
    outputs = network(inputs)
    loss = lossFct(outputs, targets)
    network.zero_grad()
    loss.backward()
    optimizer.step()
    totalLoss += float(loss)
  return totalLoss

def _devUAS(debug, devFile, modelDir, network, dicts) :
  """Decode `devFile` into modelDir/predicted_dev.conllu and return its UAS F1."""
  outFilename = modelDir+"/predicted_dev.conllu"
  # Close (and therefore flush) the predictions before reading them back.
  with open(outFilename, "w", encoding="utf-8") as outFile :
    Decode.decodeMode(debug, devFile, "model", network, dicts, outFile)
  with open(devFile, "r", encoding="utf-8") as goldFile, open(outFilename, "r", encoding="utf-8") as predFile :
    res = evaluate(load_conllu(goldFile), load_conllu(predFile), [])
  return res["UAS"][0].f1

def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
  """Train Networks.BaseNet to imitate the oracle on `sentences`.

  Steps:
    1. Build FORM/UPOS dictionaries from `filename` and save them to
       modelDir/dicts.json.
    2. Gather the examples extractExamples produces for every sentence.
    3. Train for `nbIter` epochs with Adam (lr=0.0001) and cross-entropy,
       reshuffling the examples at the start of each epoch.

  After each epoch, the summed loss is printed on stderr. When `devFile` is
  given, the UAS on that corpus is printed as well.

  Raises RuntimeError if no training example could be extracted.
  """
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM", "UPOS"])
  dicts.save(modelDir+"/dicts.json")

  print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
  examples = []
  for config in sentences :
    examples += extractExamples(transitionSet, strategy, config, dicts, debug)
  print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)
  if not examples :
    raise RuntimeError("no training example could be extracted from '%s'"%filename)
  # Each row is [targetIndex, feature_1, ..., feature_k].
  examples = torch.stack(examples)

  network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
  optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
  lossFct = torch.nn.CrossEntropyLoss()
  for epoch in range(1, nbIter+1) :
    # NOTE(review): dev decoding might switch the network to eval mode (not
    # visible here), so training mode is restored at the start of every epoch.
    network.train()
    examples = examples.index_select(0, torch.randperm(examples.size(0)))
    totalLoss = _trainEpoch(network, optimizer, lossFct, examples, batchSize, silent)
    devScore = ""
    if devFile is not None :
      devScore = ", Dev : UAS=%.2f"%(_devUAS(debug, devFile, modelDir, network, dicts))
    print("%s : Epoch %d, loss=%.2f%s"%(timeStamp(), epoch, totalLoss, devScore), file=sys.stderr)
################################################################################