import sys
import random
import copy
import math

import torch

from Transition import Transition, getMissingLinks, applyTransition
import Features

from Dicts import Dicts

from Util import timeStamp, prettyInt, numParameters, getDevice

from Rl import ReplayMemory, selectAction, optimizeModel, rewarding

import Networks
import Decode
import Config

from conll18_ud_eval import load_conllu, evaluate
    
################################################################################

def trainMode(debug, networkName, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
  sentences = Config.readConllu(filename, predicted)

  if type == "oracle" :
    trainModelOracle(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent)
    return

  if type == "rl" :
    trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent)
    return

  print("ERROR : unknown type '%s'"%type, file=sys.stderr)
  exit(1)
################################################################################
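# A minimal usage sketch of trainMode, with hypothetical argument values (the
# real call site builds these from command-line options elsewhere) :
#   trainMode(debug=False, networkName="base", filename="train.conllu",
#             type="oracle", transitionSet=transitionSets, strategy=strategy,
#             modelDir="model", nbIter=10, batchSize=64, devFile="dev.conllu",
#             bootstrapInterval=None, incremental=False, rewardFunc="A",
#             lr=0.0001, gamma=0.99, probas=None, countBreak=None,
#             predicted=["HEAD","DEPREL"], silent=False)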
    
    
################################################################################

# Return a list of training examples for each transition set.
def extractExamples(debug, transitionSets, strat, config, dicts, network, dynamic) :
  examples = [[] for _ in transitionSets]

  with torch.no_grad() :
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    config.state = 0

    moved = True
    while moved :
      ts = transitionSets[config.state]
      missingLinks = getMissingLinks(config)
      candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and trans.name != "BACK"])
      if len(candidates) == 0 :
        break

      best = min([cand[0] for cand in candidates])
      candidateOracle = random.sample([cand for cand in candidates if cand[0] == best], 1)[0][1]
      features = network.extractFeatures(dicts, config)
      candidate = candidateOracle

      if debug :
        config.printForDebug(sys.stderr)
        print(str([[c[0],str(c[1])] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)

      # In dynamic mode, follow the model's own prediction instead of the
      # oracle's choice; the oracle transition remains the training target.
      if dynamic :
        network.setState(config.state)
        output = network(features.unsqueeze(0).to(getDevice()))
        scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1]
        candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1]
        if debug :
          print(str(candidate), file=sys.stderr)

      goldIndex = [str(trans) for trans in ts].index(str(candidateOracle))
      example = torch.cat([torch.LongTensor([goldIndex]), features])
      examples[config.state].append(example)

      moved = applyTransition(strat, config, candidate, None)

    EOS.apply(config, strat)

  return examples
################################################################################
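# Each example above is a flat tensor whose first cell is the gold transition
# index and whose remaining cells are the feature values; trainModelOracle
# later splits a batch back into targets (batch[:,:1]) and inputs (batch[:,1:]).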
    
    
################################################################################

def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted) :
  col2metric = {"HEAD" : "UAS", "DEPREL" : "LAS", "UPOS" : "UPOS", "FEATS" : "UFeats"}

  devScore = ""
  # Without a dev file, save whenever the training loss improves.
  saved = True if bestLoss is None else totalLoss < bestLoss
  bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
  if devFile is not None :
    outFilename = modelDir+"/predicted_dev.conllu"
    Decode.decodeMode(debug, devFile, "model", ts, strat, rewardFunc, predicted, modelDir, model, dicts, open(outFilename, "w"))

    res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
    toEval = sorted([col for col in predicted])
    scores = [res[col2metric[col]][0].f1 for col in toEval]
    score = sum(scores)/len(scores)

    # With a dev file, save whenever the mean dev F1 improves instead.
    saved = True if bestScore is None else score > bestScore
    bestScore = score if bestScore is None else max(bestScore, score)

    devScore = ", Dev : "+" ".join(["%s=%.2f"%(col2metric[toEval[i]], scores[i]) for i in range(len(toEval))])

  if saved :
    torch.save(model, modelDir+"/network.pt")

  for out in [sys.stderr, open(modelDir+"/train.log", "w" if epoch == 1 else "a")] :
    print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=out)

  return bestLoss, bestScore
################################################################################
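# A logged epoch looks like the following line, written both to stderr and to
# train.log (hypothetical timestamp and scores, with predicted=["HEAD","DEPREL"]) :
#   2021-01-01 12:00:00 : Epoch  3/10, loss= 42.17, Dev : LAS=80.11 UAS=85.32 SAVED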
    
################################################################################

def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
  dicts = Dicts()
  dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)

  transitionNames = {}
  for ts in transitionSets :
    for t in ts :
      transitionNames[str(t)] = (len(transitionNames), 0)
  transitionNames[dicts.nullToken] = (len(transitionNames), 0)
  dicts.addDict("HISTORY", transitionNames)

  dicts.save(modelDir+"/dicts.json")

  network = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())

  examples = [[] for _ in transitionSets]
  sentences = copy.deepcopy(sentencesOriginal)

  print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
  for config in sentences :
    extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, False)
    for e in range(len(examples)) :
      examples[e] += extracted[e]

  totalNbExamples = sum(map(len,examples))
  print("%s : Extracted %s examples"%(timeStamp(), prettyInt(totalNbExamples, 3)), file=sys.stderr)

  for e in range(len(examples)) :
    examples[e] = torch.stack(examples[e])

  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)

  optimizer = torch.optim.Adam(network.parameters(), lr=lr)
  lossFct = torch.nn.CrossEntropyLoss()

  bestLoss = None
  bestScore = None

  for epoch in range(1,nbEpochs+1) :
    # Every bootstrapInterval epochs, re-extract examples dynamically : the
    # configurations are reached by following the model's own predictions.
    if bootstrapInterval is not None and epoch > 1 and (epoch-1) % bootstrapInterval == 0 :
      examples = [[] for _ in transitionSets]
      sentences = copy.deepcopy(sentencesOriginal)
      print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr)
      for config in sentences :
        extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, True)
        for e in range(len(examples)) :
          examples[e] += extracted[e]

      totalNbExamples = sum(map(len,examples))
      print("%s : Extracted %s examples"%(timeStamp(), prettyInt(totalNbExamples, 3)), file=sys.stderr)

      for e in range(len(examples)) :
        examples[e] = torch.stack(examples[e])

    network.train()

    # Shuffle the examples of each transition set.
    for e in range(len(examples)) :
      examples[e] = examples[e].index_select(0, torch.randperm(examples[e].size(0)))

    totalLoss = 0.0
    nbEx = 0
    printInterval = 2000
    advancement = 0

    # Sample each batch from a transition set chosen proportionally to its
    # number of examples; once a set is exhausted, fall back to any set that
    # still has examples left.
    distribution = [len(e)/totalNbExamples for e in examples]
    curIndexes = [0 for _ in examples]

    while True :
      state = random.choices(population=range(len(examples)), weights=distribution, k=1)[0]
      if curIndexes[state] >= len(examples[state]) :
        state = -1
        for i in range(len(examples)) :
          if curIndexes[i] < len(examples[i]) :
            state = i
      if state == -1 :
        break

      batch = examples[state][curIndexes[state]:curIndexes[state]+batchSize].to(getDevice())
      curIndexes[state] += batchSize

      # First cell of each example is the gold index, the rest are features.
      targets = batch[:,:1].view(-1)
      inputs = batch[:,1:]
      nbEx += targets.size(0)
      advancement += targets.size(0)
      if not silent and advancement >= printInterval :
        advancement = 0
        print("Current epoch %6.2f%%"%(100.0*nbEx/totalNbExamples), end="\r", file=sys.stderr)

      network.setState(state)
      outputs = network(inputs)
      loss = lossFct(outputs, targets)
      network.zero_grad()
      loss.backward()
      optimizer.step()
      totalLoss += float(loss)

    bestLoss, bestScore = evalModelAndSave(debug, network, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc, predicted)

################################################################################
    
    
################################################################################

def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
  memory = None

  dicts = Dicts()
  dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)

  transitionNames = {}
  for ts in transitionSets :
    for t in ts :
      transitionNames[str(t)] = (len(transitionNames), 0)
  transitionNames[dicts.nullToken] = (len(transitionNames), 0)
  dicts.addDict("HISTORY", transitionNames)

  dicts.save(modelDir + "/dicts.json")

  policy_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
  target_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
  target_net.load_state_dict(policy_net.state_dict())
  target_net.eval()
  policy_net.train()
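  # DQN-style setup : policy_net is trained and used to select actions, while
  # target_net is a periodically synchronized frozen copy, passed to
  # optimizeModel together with the discount factor gamma.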
    
  optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)

  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)

  bestLoss = None
  bestScore = None

  sentences = copy.deepcopy(sentencesOriginal)
  nbExByEpoch = sum(map(len,sentences))
  sentIndex = 0
    
    
  for epoch in range(1,nbIter+1) :
    probaRandom = round((probas[0][0]-probas[0][2])*math.exp((-epoch+1)/probas[0][1])+probas[0][2], 2)
    probaOracle = round((probas[1][0]-probas[1][2])*math.exp((-epoch+1)/probas[1][1])+probas[1][2], 2)
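    # The two probabilities above decay exponentially from probas[i][0] towards
    # probas[i][2] with time constant probas[i][1]. For instance, with the
    # hypothetical schedule probas[0] = [1.0, 4, 0.1], probaRandom takes the
    # values 1.00, 0.80, 0.65, ... over the first epochs.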
    
    i = 0
    totalLoss = 0.0
    
    while True :
      if sentIndex >= len(sentences) :
        sentences = copy.deepcopy(sentencesOriginal)
        random.shuffle(sentences)
        sentIndex = 0

      if not silent :
        print("Current epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr)

      sentence = sentences[sentIndex]
      sentence.moveWordIndex(0)
      state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
      # Count the transitions applied to this sentence, to stop at countBreak.
      count = 0
    
      while True :
        missingLinks = getMissingLinks(sentence)
        transitionSet = transitionSets[sentence.state]
        fromState = sentence.state
        toState = sentence.state

        if debug :
          sentence.printForDebug(sys.stderr)

        action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle, fromState)
        if action is None :
          break

        if debug :
          print("Selected action : %s"%str(action), file=sys.stderr)

        appliable = action.appliable(sentence)

        reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc)
        reward = torch.FloatTensor([reward_]).to(getDevice())

        newState = None
        if appliable :
          applyTransition(strategy, sentence, action, reward_)
          toState = sentence.state
          newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())

        # One replay memory per (fromState, toState) pair of transition sets.
        if memory is None :
          memory = [[ReplayMemory(5000, state.numel(), f, t) for t in range(len(transitionSets))] for f in range(len(transitionSets))]
        memory[fromState][toState].push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)

        state = newState
        if i % batchSize == 0 :
          totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)
          # Keep the target network in sync with the policy network.
          if i % (1*batchSize) == 0 :
            target_net.load_state_dict(policy_net.state_dict())
            target_net.eval()
            policy_net.train()
        i += 1
        count += 1

        if state is None or count == countBreak :
          break

      if i >= nbExByEpoch :
        break
      sentIndex += 1
    
    bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted)

################################################################################