    import sys
    import random
    import copy
    import math

    import torch
    
    from Transition import Transition, getMissingLinks, applyTransition
    import Features
    
    from Dicts import Dicts
    
    from Util import timeStamp, prettyInt, numParameters, getDevice
    
    from Rl import ReplayMemory, selectAction, optimizeModel, rewarding
    
    import Networks
    import Decode
    import Config
    
    from conll18_ud_eval import load_conllu, evaluate
    
    ################################################################################
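    # Entry point for training : dispatches to the oracle trainer ("oracle")
    # or to the reinforcement learning trainer ("rl") depending on 'type'.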
    
    def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, silent=False) :
    
      sentences = Config.readConllu(filename)
    
      if type == "oracle" :
    
        trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, silent)
    
        return

      if type == "rl" :
    
        trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, silent)
    
        return
    
    
      print("ERROR : unknown type '%s'"%type, file=sys.stderr)
      exit(1)
    ################################################################################
    
    
    ################################################################################
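    # Run one configuration to completion and collect classifier examples ; each example is a
    # tensor [gold transition index, features...]. With dynamic=True the transition actually
    # applied is the one predicted by the current network instead of the oracle one.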
    
    def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
      examples = []
      with torch.no_grad() :
        EOS = Transition("EOS")
        config.moveWordIndex(0)
        moved = True
        while moved :
          missingLinks = getMissingLinks(config)
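          # Oracle cost of every appliable transition (BACK excluded) ; the lowest cost is the gold choice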
    
          candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and "BACK" not in trans.name])
    
          if len(candidates) == 0 :
            break
          best = min([cand[0] for cand in candidates])
          candidateOracle = random.sample([cand for cand in candidates if cand[0] == best], 1)[0][1]
    
          features = network.extractFeatures(dicts, config)
    
          candidate = candidateOracle.name
          if debug :
            config.printForDebug(sys.stderr)
            print(str([[c[0],c[1].name] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
    
          if dynamic :
            output = network(features.unsqueeze(0).to(getDevice()))
            scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
            candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1]
            if debug :
              print(candidate, file=sys.stderr)

          goldIndex = [trans.name for trans in ts].index(candidateOracle.name)
          candidateIndex = [trans.name for trans in ts].index(candidate)
          example = torch.cat([torch.LongTensor([goldIndex]), features])
          examples.append(example)
    
          moved = applyTransition(ts, strat, config, candidate, None)
    
        EOS.apply(config, strat)
    
      return examples
    ################################################################################
    
    
    ################################################################################
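    # Log the epoch's loss, decode the dev set if one is given to compute UAS, and save the
    # model whenever the dev score (or, without a dev file, the training loss) improves.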
    
    def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc) :
    
      devScore = ""
      saved = True if bestLoss is None else totalLoss < bestLoss
      bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
      if devFile is not None :
        outFilename = modelDir+"/predicted_dev.conllu"
    
        Decode.decodeMode(debug, devFile, "model", ts, strat, rewardFunc, modelDir, model, dicts, open(outFilename, "w"))
    
        res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
        UAS = res["UAS"][0].f1
        score = UAS
        saved = True if bestScore is None else score > bestScore
        bestScore = score if bestScore is None else max(bestScore, score)
        devScore = ", Dev : UAS=%.2f"%(UAS)
      if saved :
        torch.save(model, modelDir+"/network.pt")
    
      for out in [sys.stderr, open(modelDir+"/train.log", "w" if epoch == 1 else "a")] :
    
        print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=out)
    
    
      return bestLoss, bestScore
    ################################################################################
    
    ################################################################################
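    # Supervised training against the oracle : examples are extracted once from the gold
    # derivations, then optionally re-extracted every 'bootstrapInterval' epochs by following
    # the current network's predictions.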
    
    def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, silent=False) :

      dicts = Dicts()
      dicts.readConllu(filename, ["FORM","UPOS"], 2)
    
      dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
    
      dicts.save(modelDir+"/dicts.json")
    
      network = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
    
      examples = []
      sentences = copy.deepcopy(sentencesOriginal)
    
      print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
      for config in sentences :
    
        examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, False)
    
      print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
    
      examples = torch.stack(examples)
    
    
      print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
    
      optimizer = torch.optim.Adam(network.parameters(), lr=lr)
    
      lossFct = torch.nn.CrossEntropyLoss()
    
      bestLoss = None
      bestScore = None
    
      for epoch in range(1,nbEpochs+1) :
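        # Bootstrapping : periodically rebuild the examples by letting the current network
        # choose the transitions (dynamic extraction) instead of following the oracle.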
    
        if bootstrapInterval is not None and epoch > 1 and (epoch-1) % bootstrapInterval == 0 :
          examples = []
          sentences = copy.deepcopy(sentencesOriginal)
          print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr)
          for config in sentences :
    
            examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, True)
    
          print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
          examples = torch.stack(examples)
    
    
        network.train()
    
        examples = examples.index_select(0, torch.randperm(examples.size(0)))
        totalLoss = 0.0
        nbEx = 0
        printInterval = 2000
        advancement = 0
        for batchIndex in range(0,examples.size(0)-batchSize,batchSize) :
    
          batch = examples[batchIndex:batchIndex+batchSize].to(getDevice())
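          # First column of each row is the gold transition index, the remaining columns are the features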
    
          targets = batch[:,:1].view(-1)
          inputs = batch[:,1:]
          nbEx += targets.size(0)
          advancement += targets.size(0)
          if not silent and advancement >= printInterval :
            advancement = 0
    
            print("Current epoch %6.2f%%"%(100.0*nbEx/examples.size(0)), end="\r", file=sys.stderr)
    
          outputs = network(inputs)
          loss = lossFct(outputs, targets)
          network.zero_grad()
          loss.backward()
          optimizer.step()
          totalLoss += float(loss)
    
        bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc)
    
    ################################################################################
    
    
    ################################################################################
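    # Reinforcement learning trainer (DQN style) : a policy network is optimized from a replay
    # memory with targets produced by a periodically synchronized target network ; actions are
    # chosen greedily, with decaying probabilities of random and oracle moves.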
    
    def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, silent=False) :
    
      memory = None
    
      dicts = Dicts()
    
      dicts.readConllu(filename, ["FORM","UPOS"], 2)
    
      dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
    
      dicts.save(modelDir + "/dicts.json")
    
    
      policy_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
      target_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
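      # The target network starts as a frozen copy of the policy network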
    
      target_net.load_state_dict(policy_net.state_dict())
      target_net.eval()
      policy_net.train()
    
      optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)
    
      print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
    
    
      bestLoss = None
      bestScore = None
    
    
      sentences = copy.deepcopy(sentencesOriginal)
      nbExByEpoch = sum(map(len,sentences))
      sentIndex = 0
    
    
      for epoch in range(1,nbIter+1) :
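        # Exploration schedule : each probas entry is (start, decay, minimum) ; the probability of
        # a random (resp. oracle) action decays exponentially from start towards minimum over epochs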
    
        probaRandom = round((probas[0][0]-probas[0][2])*math.exp((-epoch+1)/probas[0][1])+probas[0][2], 2)
        probaOracle = round((probas[1][0]-probas[1][2])*math.exp((-epoch+1)/probas[1][1])+probas[1][2], 2)
    
        i = 0
        totalLoss = 0.0
    
        while True :
          if sentIndex >= len(sentences) :
            sentences = copy.deepcopy(sentencesOriginal)
            random.shuffle(sentences)
            sentIndex = 0
    
    
          if not silent :
    
            print("Current epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr)
    
          sentence = sentences[sentIndex]
          sentence.moveWordIndex(0)
    
          state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
    
          while True :
            missingLinks = getMissingLinks(sentence)
            if debug :
              sentence.printForDebug(sys.stderr)
    
            action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle)
    
            if action is None :
              break
    
    
            if debug :
              print("Selected action : %s"%action.name, file=sys.stderr)
    
    
            appliable = action.appliable(sentence)
    
    
            reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc)
    
            reward = torch.FloatTensor([reward_]).to(getDevice())
    
    
    
            newState = None
            if appliable :
    
              applyTransition(transitionSet, strategy, sentence, action.name, reward_)
    
              newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
    
    
    
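            # Create the replay memory lazily (its slot size depends on the feature vector) and store the transition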
            if memory is None :
    
              memory = ReplayMemory(5000, state.numel())
    
            memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
    
            state = newState
            if i % batchSize == 0 :
    
              totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)
    
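              # Periodically copy the policy network's weights into the target network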
              if i % (1*batchSize) == 0 :
    
                target_net.load_state_dict(policy_net.state_dict())
                target_net.eval()
                policy_net.train()
            i += 1
    
    
            if state is None :
              break
          if i >= nbExByEpoch :
            break
          sentIndex += 1
    
        bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc)
    
    ################################################################################