import sys
import random
import math
import torch
import copy
from Transition import Transition, getMissingLinks, applyTransition
import Features
from Dicts import Dicts
from Util import timeStamp, prettyInt, numParameters, getDevice
from Rl import ReplayMemory, selectAction, optimizeModel, rewarding
import Networks
import Decode
import Config
from conll18_ud_eval import load_conllu, evaluate

################################################################################
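# Dispatch training to the "oracle" (supervised) or "rl" (deep Q-learning) regime, depending on 'type'.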
def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
  sentences = Config.readConllu(filename, predicted)

  if type == "oracle" :
    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent)
    return

  if type == "rl" :
    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent)
    return

  print("ERROR : unknown type '%s'"%type, file=sys.stderr)
  exit(1)
################################################################################

################################################################################
# Return a list of training examples (gold transition index + features) for each transition set
def extractExamples(debug, transitionSets, strat, config, dicts, network, dynamic) :
  examples = [[] for _ in transitionSets]
  with torch.no_grad() :
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    config.state = 0
    moved = True
    while moved :
      ts = transitionSets[config.state]
      missingLinks = getMissingLinks(config)
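      # Rank appliable transitions by oracle cost (lower is better); BACK transitions are never used as oracle targets.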
      candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and "BACK" not in trans.name])
      if len(candidates) == 0 :
        break
      best = min([cand[0] for cand in candidates])
      candidateOracle = random.sample([cand for cand in candidates if cand[0] == best], 1)[0][1]
      features = network.extractFeatures(dicts, config)
      candidate = candidateOracle
      if debug :
        config.printForDebug(sys.stderr)
        print(str([[c[0],str(c[1])] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
      # In dynamic mode, follow the model's own best appliable prediction instead of the oracle's choice.
      if dynamic :
        network.setState(config.state)
        output = network(features.unsqueeze(0).to(getDevice()))
        scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1]
        candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1]
        if debug :
          print(str(candidate), file=sys.stderr)
      goldIndex = [str(trans) for trans in ts].index(str(candidateOracle))
      example = torch.cat([torch.LongTensor([goldIndex]), features])
      examples[config.state].append(example)
      moved = applyTransition(strat, config, candidate, None)
    EOS.apply(config, strat)
  return examples
################################################################################

################################################################################
def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted) :
  col2metric = {"HEAD" : "UAS", "DEPREL" : "LAS", "UPOS" : "UPOS", "FEATS" : "UFeats"}

  devScore = ""
  saved = True if bestLoss is None else totalLoss < bestLoss
  bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
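  # Without a dev set the best checkpoint is the one with the lowest training loss; with one, the best average dev F1.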
  if devFile is not None :
    outFilename = modelDir+"/predicted_dev.conllu"
    Decode.decodeMode(debug, devFile, "model", ts, strat, rewardFunc, predicted, modelDir, model, dicts, open(outFilename, "w"))
    res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
    toEval = sorted([col for col in predicted])
    scores = [res[col2metric[col]][0].f1 for col in toEval]
    score = sum(scores)/len(scores)
    saved = True if bestScore is None else score > bestScore
    bestScore = score if bestScore is None else max(bestScore, score)
    devScore = ", Dev : "+" ".join(["%s=%.2f"%(col2metric[toEval[i]], scores[i]) for i in range(len(toEval))])
  if saved :
    torch.save(model, modelDir+"/network.pt")
  for out in [sys.stderr, open(modelDir+"/train.log", "w" if epoch == 1 else "a")] :
    print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=out)

  return bestLoss, bestScore
################################################################################

################################################################################
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
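  # Register every transition name (plus the null token) in the HISTORY dict, presumably so the recent action history can be encoded as a feature.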
  transitionNames = {}
  for ts in transitionSets :
    for t in ts :
      transitionNames[str(t)] = (len(transitionNames), 0)
  transitionNames[dicts.nullToken] = (len(transitionNames), 0)
  dicts.addDict("HISTORY", transitionNames)

  dicts.save(modelDir+"/dicts.json")
  network = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
  examples = [[] for _ in transitionSets]
  sentences = copy.deepcopy(sentencesOriginal)
  print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
  for config in sentences :
    extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, False)
    for e in range(len(examples)) :
      examples[e] += extracted[e]

  totalNbExamples = sum(map(len,examples))
  print("%s : Extracted %s examples"%(timeStamp(), prettyInt(totalNbExamples, 3)), file=sys.stderr)
  for e in range(len(examples)) :
    examples[e] = torch.stack(examples[e])
  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
  optimizer = torch.optim.Adam(network.parameters(), lr=lr)
  lossFct = torch.nn.CrossEntropyLoss()
  bestLoss = None
  bestScore = None
  for epoch in range(1,nbEpochs+1) :
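    # With bootstrapping enabled, periodically re-extract examples using the current model's own predictions (dynamic oracle).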
    if bootstrapInterval is not None and epoch > 1 and (epoch-1) % bootstrapInterval == 0 :
      examples = [[] for _ in transitionSets]
      sentences = copy.deepcopy(sentencesOriginal)
      print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr)
      for config in sentences :
        extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, True)
        for e in range(len(examples)) :
          examples[e] += extracted[e]
      totalNbExamples = sum(map(len,examples))
      print("%s : Extracted %s examples"%(timeStamp(), prettyInt(totalNbExamples, 3)), file=sys.stderr)
      for e in range(len(examples)) :
        examples[e] = torch.stack(examples[e])
    network.train()
    for e in range(len(examples)) :
      examples[e] = examples[e].index_select(0, torch.randperm(examples[e].size(0)))
    totalLoss = 0.0
    nbEx = 0
    printInterval = 2000
    advancement = 0

    distribution = [len(e)/totalNbExamples for e in examples]
    curIndexes = [0 for _ in examples]

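    # Draw each batch from a state sampled proportionally to its share of the examples;
    # once the sampled state is exhausted, fall back to any state that still has examples left.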
    while True :
      state = random.choices(population=range(len(examples)), weights=distribution, k=1)[0]
      if curIndexes[state] >= len(examples[state]) :
        state = -1
        for i in range(len(examples)) :
          if curIndexes[i] < len(examples[i]) :
            state = i
      if state == -1 :
        break

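      # Each example row is the gold transition index followed by the feature vector.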
      batch = examples[state][curIndexes[state]:curIndexes[state]+batchSize].to(getDevice())
      curIndexes[state] += batchSize
      targets = batch[:,:1].view(-1)
      inputs = batch[:,1:]
      nbEx += targets.size(0)
      advancement += targets.size(0)
      if not silent and advancement >= printInterval :
        advancement = 0
        print("Current epoch %6.2f%%"%(100.0*nbEx/totalNbExamples), end="\r", file=sys.stderr)

      network.setState(state)
      outputs = network(inputs)
      loss = lossFct(outputs, targets)
      network.zero_grad()
      loss.backward()
      optimizer.step()
      totalLoss += float(loss)
    bestLoss, bestScore = evalModelAndSave(debug, network, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc, predicted)
################################################################################

################################################################################
def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :

  memory = None
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
  transitionNames = {}
  for ts in transitionSets :
    for t in ts :
      transitionNames[str(t)] = (len(transitionNames), 0)
  transitionNames[dicts.nullToken] = (len(transitionNames), 0)
  dicts.addDict("HISTORY", transitionNames)
  dicts.save(modelDir + "/dicts.json")

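  # DQN setup: the policy network is trained while a periodically-synced target network provides stable value estimates.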
  policy_net = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
  target_net = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
  target_net.load_state_dict(policy_net.state_dict())
  target_net.eval()
  policy_net.train()
  optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)
  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)

  bestLoss = None
  bestScore = None

  sentences = copy.deepcopy(sentencesOriginal)
  nbExByEpoch = sum(map(len,sentences))
  sentIndex = 0

  for epoch in range(1,nbIter+1) :
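    # Exploration schedule: the probabilities of taking a random or oracle action decay exponentially
    # from probas[x][0] towards probas[x][2], with time constant probas[x][1].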
    probaRandom = round((probas[0][0]-probas[0][2])*math.exp((-epoch+1)/probas[0][1])+probas[0][2], 2)
    probaOracle = round((probas[1][0]-probas[1][2])*math.exp((-epoch+1)/probas[1][1])+probas[1][2], 2)
    i = 0
    totalLoss = 0.0
    while True :
      if sentIndex >= len(sentences) :
        sentences = copy.deepcopy(sentencesOriginal)
        random.shuffle(sentences)
        sentIndex = 0

      if not silent :
        print("Current epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr)
      sentence = sentences[sentIndex]
      sentence.moveWordIndex(0)
      state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
      # count caps the number of transitions applied to one sentence at countBreak, to avoid infinite loops.
      count = 0
      while True :
        missingLinks = getMissingLinks(sentence)
        transitionSet = transitionSets[sentence.state]
        if debug :
          sentence.printForDebug(sys.stderr)
        action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle)
        if action is None :
          break

        if debug :
          print("Selected action : %s"%str(action), file=sys.stderr)
        appliable = action.appliable(sentence)

        reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc)
        reward = torch.FloatTensor([reward_]).to(getDevice())

        newState = None
        if appliable :
          applyTransition(strategy, sentence, action, reward_)
          newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())

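        # The replay memory is created lazily, once the feature vector size is known (capacity 5000 transitions).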
        if memory is None :
          memory = ReplayMemory(5000, state.numel())
        memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
        state = newState
        if i % batchSize == 0 :
          totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)
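          # Sync the target network with the policy network after each optimization step.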
          if i % (1*batchSize) == 0 :
            target_net.load_state_dict(policy_net.state_dict())
            target_net.eval()
            policy_net.train()
        i += 1
        count += 1
        if state is None or count == countBreak :
          break
      if i >= nbExByEpoch :
        break
      sentIndex += 1
    bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted)
################################################################################