Train.py
    import sys
    import random
    import torch
    import copy
    import math
    
    from Transition import Transition, getMissingLinks, applyTransition
    import Features
    from Dicts import Dicts
    from Util import timeStamp, prettyInt, numParameters, getDevice
    from Rl import ReplayMemory, selectAction, optimizeModel, rewarding
    import Networks
    import Decode
    import Config
    
    from conll18_ud_eval import load_conllu, evaluate
    
    ################################################################################
    def trainMode(debug, networkName, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
      sentences = Config.readConllu(filename, predicted)
    
      if type == "oracle" :
        trainModelOracle(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent)
        return
    
      if type == "rl":
        trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent)
        return
    
      print("ERROR : unknown type '%s'"%type, file=sys.stderr)
      sys.exit(1)
    ################################################################################
    
    ################################################################################
    # Return one list of training examples per transition set : each example is the gold transition index concatenated with the feature vector
    def extractExamples(debug, transitionSets, strat, config, dicts, network, dynamic) :
      examples = [[] for _ in transitionSets]
      with torch.no_grad() :
        EOS = Transition("EOS")
        config.moveWordIndex(0)
        config.state = 0
        moved = True
        while moved :
          ts = transitionSets[config.state]
          missingLinks = getMissingLinks(config)
          candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and trans.name != "BACK"], key=lambda c : c[0])
          if len(candidates) == 0 :
            break
          best = min([cand[0] for cand in candidates])
          candidateOracle = random.sample([cand for cand in candidates if cand[0] == best], 1)[0][1]
          features = network.extractFeatures(dicts, config)
          candidate = candidateOracle
          if debug :
            config.printForDebug(sys.stderr)
            print(str([[c[0],str(c[1])] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
          if dynamic :
            network.setState(config.state)
            output = network(features.unsqueeze(0).to(getDevice()))
            scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))], key=lambda s : s[0], reverse=True)
            candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1]
            if debug :
              print(str(candidate), file=sys.stderr)
    
          goldIndex = [str(trans) for trans in ts].index(str(candidateOracle))
          example = torch.cat([torch.LongTensor([goldIndex]), features])
          examples[config.state].append(example)
    
          moved = applyTransition(strat, config, candidate, None)
    
        EOS.apply(config, strat)
    
      return examples
    ################################################################################
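    
    ################################################################################
    # Illustrative sketch, not used by the training code : each example produced by
    # extractExamples is a single tensor whose first element is the gold transition
    # index and whose remaining elements are the feature vector ; trainModelOracle
    # splits it back exactly like this (see targets/inputs below). The helper name
    # splitExample is hypothetical.
    def splitExample(example) :
      goldIndex = int(example[0])
      features = example[1:]
      return goldIndex, features
    ################################################################################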
    
    ################################################################################
    def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted) :
      col2metric = {"HEAD" : "UAS", "DEPREL" : "LAS", "UPOS" : "UPOS", "FEATS" : "UFeats"}
    
      devScore = ""
      saved = True if bestLoss is None else totalLoss < bestLoss
      bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
      if devFile is not None :
        outFilename = modelDir+"/predicted_dev.conllu"
        with open(outFilename, "w") as outFile :
          Decode.decodeMode(debug, devFile, "model", ts, strat, rewardFunc, predicted, modelDir, model, dicts, outFile)
        with open(devFile, "r") as goldFile, open(outFilename, "r") as predFile :
          res = evaluate(load_conllu(goldFile), load_conllu(predFile), [])
        toEval = sorted([col for col in predicted])
        scores = [res[col2metric[col]][0].f1 for col in toEval]
        score = sum(scores)/len(scores)
        saved = True if bestScore is None else score > bestScore
        bestScore = score if bestScore is None else max(bestScore, score)
        devScore = ", Dev : "+" ".join(["%s=%.2f"%(col2metric[toEval[i]], scores[i]) for i in range(len(toEval))])
      if saved :
        torch.save(model, modelDir+"/network.pt")
      for out in [sys.stderr, open(modelDir+"/train.log", "w" if epoch == 1 else "a")] :
        print("{} : Epoch {:{}}/{}, loss={:6.2f}{} {}".format(timeStamp(), epoch, len(str(nbIter)), nbIter, totalLoss, devScore, "SAVED" if saved else ""), file=out)
    
      return bestLoss, bestScore
    ################################################################################
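    
    ################################################################################
    # Illustrative sketch, not called anywhere : evalModelAndSave keeps a checkpoint
    # when the averaged dev score improves (if a dev file is given), otherwise when the
    # training loss improves. The helper name shouldSave is hypothetical.
    def shouldSave(bestScore, score, bestLoss, loss) :
      if score is not None :
        return bestScore is None or score > bestScore
      return bestLoss is None or loss < bestLoss
    ################################################################################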
    
    ################################################################################
    def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
      dicts = Dicts()
      dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
      transitionNames = {}
      for ts in transitionSets :
        for t in ts :
          transitionNames[str(t)] = (len(transitionNames), 0)
      transitionNames[dicts.nullToken] = (len(transitionNames), 0)
      dicts.addDict("HISTORY", transitionNames)
    
      dicts.save(modelDir+"/dicts.json")
      network = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
      examples = [[] for _ in transitionSets]
      sentences = copy.deepcopy(sentencesOriginal)
      print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
      for config in sentences :
        extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, False)
        for e in range(len(examples)) :
          examples[e] += extracted[e]
    
      totalNbExamples = sum(map(len,examples))
      print("%s : Extracted %s examples"%(timeStamp(), prettyInt(totalNbExamples, 3)), file=sys.stderr)
      for e in range(len(examples)) :
        examples[e] = torch.stack(examples[e])
    
      print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
      optimizer = torch.optim.Adam(network.parameters(), lr=lr)
      lossFct = torch.nn.CrossEntropyLoss()
      bestLoss = None
      bestScore = None
      for epoch in range(1,nbEpochs+1) :
        if bootstrapInterval is not None and epoch > 1 and (epoch-1) % bootstrapInterval == 0 :
          examples = [[] for _ in transitionSets]
          sentences = copy.deepcopy(sentencesOriginal)
          print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr)
          for config in sentences :
            extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, True)
            for e in range(len(examples)) :
              examples[e] += extracted[e]
          totalNbExamples = sum(map(len,examples))
          print("%s : Extracted %s examples"%(timeStamp(), prettyInt(totalNbExamples, 3)), file=sys.stderr)
          for e in range(len(examples)) :
            examples[e] = torch.stack(examples[e])
    
        network.train()
        for e in range(len(examples)) :
          examples[e] = examples[e].index_select(0, torch.randperm(examples[e].size(0)))
        totalLoss = 0.0
        nbEx = 0
        printInterval = 2000
        advancement = 0
    
        distribution = [len(e)/totalNbExamples for e in examples]
        curIndexes = [0 for _ in examples]
    
        while True :
          state = random.choices(population=range(len(examples)), weights=distribution, k=1)[0]
          if curIndexes[state] >= len(examples[state]) :
            state = -1
            for i in range(len(examples)) :
              if curIndexes[i] < len(examples[i]) :
                state = i
          if state == -1 :
            break
    
          batch = examples[state][curIndexes[state]:curIndexes[state]+batchSize].to(getDevice())
          curIndexes[state] += batchSize
          targets = batch[:,:1].view(-1)
          inputs = batch[:,1:]
          nbEx += targets.size(0)
          advancement += targets.size(0)
          if not silent and advancement >= printInterval :
            advancement = 0
            print("Current epoch %6.2f%%"%(100.0*nbEx/totalNbExamples), end="\r", file=sys.stderr)
    
          network.setState(state)
          outputs = network(inputs)
          loss = lossFct(outputs, targets)
          network.zero_grad()
          loss.backward()
          optimizer.step()
          totalLoss += float(loss)
    
        bestLoss, bestScore = evalModelAndSave(debug, network, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc, predicted)
    ################################################################################
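    
    ################################################################################
    # Illustrative helper, not called by trainModelRl below : the exploration
    # probabilities it computes each epoch decay exponentially from a start value
    # towards an end value, with probas[pb][x] assumed to hold [start, decay, end].
    # The helper name decayedProba is hypothetical.
    def decayedProba(start, decay, end, epoch) :
      return round((start-end)*math.exp((-epoch+1)/decay)+end, 2)
    ################################################################################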
    
    ################################################################################
    def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
    
      memory = None
      dicts = Dicts()
      dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
      transitionNames = {}
      for ts in transitionSets :
        for t in ts :
          transitionNames[str(t)] = (len(transitionNames), 0)
      transitionNames[dicts.nullToken] = (len(transitionNames), 0)
      dicts.addDict("HISTORY", transitionNames)
      dicts.save(modelDir + "/dicts.json")
    
      policy_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
      target_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
      target_net.load_state_dict(policy_net.state_dict())
      target_net.eval()
      policy_net.train()
      optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)
      print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
    
      bestLoss = None
      bestScore = None
    
      sentences = copy.deepcopy(sentencesOriginal)
      nbExByEpoch = sum(map(len,sentences))
      sentIndex = 0
    
      for epoch in range(1,nbIter+1) :
        i = 0
        totalLoss = 0.0
        while True :
          if sentIndex >= len(sentences) :
            sentences = copy.deepcopy(sentencesOriginal)
            random.shuffle(sentences)
            sentIndex = 0
    
          if not silent :
            print("Current epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr)
          sentence = sentences[sentIndex]
          sentence.moveWordIndex(0)
          state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
    
          count = 0
          list_probas = []
          # Exploration schedule : each probability decays exponentially from its start
          # value towards its end value, probas[pb][x] holding [start, decay constant, end].
          for pb in range(len(probas)) :
            list_probas.append([round((probas[pb][0][0]-probas[pb][0][2])*math.exp((-epoch+1)/probas[pb][0][1])+probas[pb][0][2], 2),
                                round((probas[pb][1][0]-probas[pb][1][2])*math.exp((-epoch+1)/probas[pb][1][1])+probas[pb][1][2], 2)])
    
          while True :
            missingLinks = getMissingLinks(sentence)
            transitionSet = transitionSets[sentence.state]
            fromState = sentence.state
            toState = sentence.state
            probaRandom = list_probas[fromState][0]
            probaOracle = list_probas[fromState][1]
    
            if debug :
              sentence.printForDebug(sys.stderr)
            action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle, fromState)
    
            if action is None :
              break
    
            if debug :
              print("Selected action : %s"%str(action), file=sys.stderr)
    
            appliable = action.appliable(sentence)
    
            reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc)
            reward = torch.FloatTensor([reward_]).to(getDevice())
    
            newState = None
            toState = strategy[fromState][action.name][1] if action.name in strategy[fromState] else -1
            if appliable :
              applyTransition(strategy, sentence, action, reward_)
              newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
            else:
              count+=1
    
            if memory is None :
              memory = [[ReplayMemory(5000, state.numel(), f, t) for t in range(len(transitionSets))] for f in range(len(transitionSets))]
            memory[fromState][toState].push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
            state = newState
            if i % batchSize == 0 :
              totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)
              if i % (1*batchSize) == 0 :
                target_net.load_state_dict(policy_net.state_dict())
                target_net.eval()
                policy_net.train()
            i += 1
    
            if state is None or count == countBreak:
              break
          if i >= nbExByEpoch :
            break
          sentIndex += 1
        bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted)
    ################################################################################
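    
    ################################################################################
    # Minimal usage sketch, assuming transition sets and a strategy built elsewhere
    # (e.g. by the project's command line entry point) ; every argument value below is
    # a hypothetical placeholder, not a default of this code base.
    #
    # if __name__ == "__main__" :
    #   transitionSets, strategy = ...  # supplied by the caller
    #   trainMode(debug=False, networkName="base", filename="train.conllu", type="oracle",
    #             transitionSet=transitionSets, strategy=strategy, modelDir="model",
    #             nbIter=10, batchSize=64, devFile="dev.conllu", bootstrapInterval=None,
    #             incremental=False, rewardFunc="A", lr=0.0001, gamma=0.99, probas=None,
    #             countBreak=1, predicted={"HEAD", "DEPREL"})
    ################################################################################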