# Train.py : training of the parser models (static oracle and reinforcement learning).
import copy
import random
import sys

import torch

from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp
from Rl import ReplayMemory, selectAction, optimizeModel
import Networks
import Decode
import Config
from conll18_ud_eval import load_conllu, evaluate

################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
  """Train a parsing model on a CoNLL-U corpus.

  debug     : forwarded to the trainer; enables debug dumps to stderr.
  filename  : CoNLL-U training file.
  type      : "oracle" for supervised training from the static oracle,
              "rl" for reinforcement learning. The name shadows the builtin
              but is kept so keyword callers keep working.
  modelDir  : directory receiving dicts.json, network.pt and, when devFile is
              given, predicted_dev.conllu.
  nbIter    : number of training epochs.
  batchSize : mini-batch size.
  devFile   : CoNLL-U dev file used for model selection, or None.
  silent    : if true, the in-epoch progress indicator is not printed.

  Prints an error and exits with status 1 if `type` is unknown.
  """
  trainers = {"oracle" : trainModelOracle, "rl" : trainModelRl}
  trainer = trainers.get(type)
  # Validate the requested type before paying for reading the corpus.
  if trainer is None :
    print("ERROR : unknown type '%s'"%type, file=sys.stderr)
    # sys.exit, not the site-provided exit(), which is not guaranteed to exist.
    sys.exit(1)

  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
  # NOTE(review): presumably 1 marks transitions after which the word index
  # advances (applyTransition reports whether it moved) — TODO confirm.
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

  sentences = Config.readConllu(filename)
  trainer(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
################################################################################

################################################################################
def extractExamples(ts, strat, config, dicts, debug=False) :
  """Follow the static oracle through one sentence and return its training examples.

  ts     : list of Transition objects; an example's label is the index of the
           chosen transition in this list.
  strat  : strategy table forwarded to applyTransition.
  config : the sentence configuration; it is modified in place (word index
           reset to 0, oracle transitions applied, then the "EOS" transition).
  dicts  : Dicts passed to Features.extractFeatures.
  debug  : if true, print each configuration and its ranked candidates to stderr.

  At each step the appliable transitions are ranked by (oracle score, name)
  in ascending order and the first one is taken. The walk stops when no
  transition is appliable or applyTransition returns a false value.

  Returns a list of 1D tensors laid out as [label, features...]
  (assumes extractFeatures returns a 1D LongTensor — TODO confirm).
  """
  eosTransition = Transition("EOS")
  transitionNames = [transition.name for transition in ts]
  result = []

  config.moveWordIndex(0)
  keepGoing = True
  while keepGoing :
    links = getMissingLinks(config)
    ranked = [[transition.getOracleScore(config, links), transition.name] for transition in ts if transition.appliable(config)]
    if not ranked :
      break
    # Lowest oracle score wins; equal scores fall back to alphabetical order.
    ranked.sort()
    chosenName = ranked[0][1]
    label = transitionNames.index(chosenName)
    featureVector = Features.extractFeatures(dicts, config)
    result.append(torch.cat([torch.LongTensor([label]), featureVector]))
    if debug :
      config.printForDebug(sys.stderr)
      print(str(ranked)+"\n"+("-"*80)+"\n", file=sys.stderr)
    keepGoing = applyTransition(ts, strat, config, chosenName)

  eosTransition.apply(config)
  return result
################################################################################

################################################################################
def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
  """Supervised training on examples extracted from the static oracle.

  Builds the FORM/UPOS dictionaries from `filename` and saves them to
  modelDir/dicts.json. It then extracts one example per oracle transition
  from every sentence of `sentences`. A Networks.BaseNet is trained on them
  with Adam (lr=0.0001) and cross-entropy for `nbIter` epochs.

  After each epoch the network is saved to modelDir/network.pt if it
  improved: on dev UAS when `devFile` is given, otherwise on the summed
  training loss. Each epoch's dev predictions are written to
  modelDir/predicted_dev.conllu.

  Raises ValueError if no training example could be extracted.
  """
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM", "UPOS"])
  dicts.save(modelDir+"/dicts.json")

  print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
  examples = []
  for config in sentences :
    examples += extractExamples(transitionSet, strategy, config, dicts, debug)
  print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)
  if not examples :
    raise ValueError("no training example could be extracted from '%s'"%filename)
  # Each row is [target transition index, features...].
  examples = torch.stack(examples)
  nbExamples = examples.size(0)

  network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
  optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
  lossFct = torch.nn.CrossEntropyLoss()
  printInterval = 2000
  bestLoss = None
  bestScore = None
  for epoch in range(1, nbIter+1) :
    network.train()
    examples = examples.index_select(0, torch.randperm(nbExamples))
    totalLoss = 0.0
    nbEx = 0
    advancement = 0
    # Cover every example: the last batch may be smaller than batchSize.
    for batchIndex in range(0, nbExamples, batchSize) :
      batch = examples[batchIndex:batchIndex+batchSize]
      targets = batch[:,:1].view(-1)
      inputs = batch[:,1:]
      nbEx += targets.size(0)
      advancement += targets.size(0)
      if not silent and advancement >= printInterval :
        advancement = 0
        print("Curent epoch %6.2f%%"%(100.0*nbEx/nbExamples), end="\r", file=sys.stderr)
      outputs = network(inputs)
      loss = lossFct(outputs, targets)
      network.zero_grad()
      loss.backward()
      optimizer.step()
      totalLoss += float(loss)

    devScore = ""
    saved = bestLoss is None or totalLoss < bestLoss
    bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
    if devFile is not None :
      # Decode in inference mode (dropout etc. disabled); train() is restored next epoch.
      network.eval()
      outFilename = modelDir+"/predicted_dev.conllu"
      # Close (and flush) the prediction file before re-reading it for evaluation.
      with open(outFilename, "w") as outFile :
        Decode.decodeMode(debug, devFile, "model", modelDir, network, dicts, outFile)
      with open(devFile, "r") as goldFile, open(outFilename, "r") as predFile :
        res = evaluate(load_conllu(goldFile), load_conllu(predFile), [])
      UAS = res["UAS"][0].f1
      # With a dev set, model selection uses UAS instead of training loss.
      saved = bestScore is None or UAS > bestScore
      bestScore = UAS if bestScore is None else max(bestScore, UAS)
      devScore = ", Dev : UAS=%.2f"%(UAS)
    if saved :
      torch.save(network, modelDir+"/network.pt")
    print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), epoch, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
################################################################################

################################################################################
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, silent=False) :
  """Train a network by reinforcement learning over the parser transitions.

  Builds the FORM/UPOS dictionaries from `filename` (saved to
  modelDir/dicts.json) and creates a policy network and a target network.

  Each epoch parses a fresh deep copy of `sentencesOriginal`:
    - actions come from selectAction (probaRandom=0.3, probaOracle=0.15);
    - the reward is the negated oracle score of the chosen action;
    - each (state, action, newState, reward) tuple is pushed to a
      ReplayMemory of capacity 1000;
    - optimizeModel runs every `batchSize` steps;
    - the target network is synchronised with the policy network every
      2*`batchSize` steps.
  The step counter is reset at every epoch.

  After each epoch the policy network is saved to modelDir/network.pt if it
  improved: on dev UAS when `devFile` is given, otherwise on the summed loss
  returned by optimizeModel.
  """
  memory = ReplayMemory(1000)
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM", "UPOS"])
  dicts.save(modelDir + "/dicts.json")

  # NOTE(review): 13 is a hard-coded input size; presumably it must match the
  # length of Features.extractFeaturesPosExtended's output — TODO confirm or derive.
  policy_net = Networks.BaseNet(dicts, 13, len(transitionSet))
  target_net = Networks.BaseNet(dicts, 13, len(transitionSet))
  target_net.load_state_dict(policy_net.state_dict())
  target_net.eval()

  optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
  bestLoss = None
  bestScore = None

  # Epochs are numbered from 1, as in trainModelOracle.
  for epoch in range(1, nbIter+1) :
    # Dev decoding at the end of the previous epoch switches to eval mode.
    policy_net.train()
    i = 0
    totalLoss = 0.0
    # Presumably transitions mutate the sentences in place, so work on a copy.
    sentences = copy.deepcopy(sentencesOriginal)
    for sentIndex, sentence in enumerate(sentences) :
      if not silent :
        print("Curent epoch %6.2f%%"%(100.0*sentIndex/len(sentences)), end="\r", file=sys.stderr)
      sentence.moveWordIndex(0)
      state = Features.extractFeaturesPosExtended(dicts, sentence)
      while True :
        missingLinks = getMissingLinks(sentence)
        if debug :
          sentence.printForDebug(sys.stderr)
        action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom=0.3, probaOracle=0.15)
        if action is None :
          break

        # Lower oracle scores are better (cf. extractExamples), so the reward is
        # the negated score, computed before the action is applied.
        reward = torch.FloatTensor([-1.0*action.getOracleScore(sentence, missingLinks)])

        applyTransition(transitionSet, strategy, sentence, action.name)
        newState = Features.extractFeaturesPosExtended(dicts, sentence)

        memory.push((state, torch.LongTensor([transitionSet.index(action)]), newState, reward))
        state = newState
        if i % batchSize == 0 :
          totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
          if i % (2*batchSize) == 0 :
            target_net.load_state_dict(policy_net.state_dict())
            target_net.eval()
            policy_net.train()
        i += 1

    # End of epoch: compute the dev score (if any) and save the model if it improved.
    devScore = ""
    saved = bestLoss is None or totalLoss < bestLoss
    bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
    if devFile is not None :
      policy_net.eval()
      outFilename = modelDir+"/predicted_dev.conllu"
      # Close (and flush) the prediction file before re-reading it for evaluation.
      with open(outFilename, "w") as outFile :
        Decode.decodeMode(debug, devFile, "model", modelDir, policy_net, dicts, outFile)
      with open(devFile, "r") as goldFile, open(outFilename, "r") as predFile :
        res = evaluate(load_conllu(goldFile), load_conllu(predFile), [])
      UAS = res["UAS"][0].f1
      saved = bestScore is None or UAS > bestScore
      bestScore = UAS if bestScore is None else max(bestScore, UAS)
      devScore = ", Dev : UAS=%.2f"%(UAS)
    if saved :
      torch.save(policy_net, modelDir+"/network.pt")
    print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), epoch, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)

################################################################################