Skip to content
Snippets Groups Projects
Train.py 4.44 KiB
Newer Older
  • Learn to ignore specific revisions
  • import sys
    import random
    
    from Transition import Transition, getMissingLinks, applyTransition
    import Features
    
    from Dicts import Dicts
    from Util import timeStamp
    import Networks
    import Decode
    import Config
    
    from conll18_ud_eval import load_conllu, evaluate
    
    ################################################################################
    def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
      """Entry point of training : check the requested mode, read the corpus and train.

      debug     : forwarded unchanged to trainModelOracle.
      filename  : path of the CoNLL-U training file, read with Config.readConllu.
      type      : training mode ; only "oracle" is supported.
      modelDir  : forwarded to trainModelOracle (directory receiving the model files).
      nbIter    : forwarded to trainModelOracle (number of epochs).
      batchSize : forwarded to trainModelOracle.
      devFile   : forwarded to trainModelOracle (may be None).
      silent    : forwarded to trainModelOracle.

      Prints an error on stderr and exits the process with status 1 when
      'type' is not a supported mode.
      """
      # Fail fast : reject an unsupported mode before building the transition
      # set and parsing the whole training corpus.
      if type != "oracle" :
        print("ERROR : unknown type '%s'"%type, file=sys.stderr)
        # sys.exit rather than the bare exit() helper : the latter is injected by
        # the 'site' module and is not guaranteed to exist (e.g. python -S).
        sys.exit(1)

      transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
      # NOTE(review): one integer per transition name, passed through to
      # applyTransition by the trainer ; its exact meaning is defined there.
      strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

      sentences = Config.readConllu(filename)

      trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
    ################################################################################
    
    
    ################################################################################
    def extractExamples(ts, strat, config, dicts, debug=False) :
      """Follow the oracle on one configuration and collect one training example per step.

      ts     : list of the available transitions.
      strat  : strategy, passed through to applyTransition.
      config : configuration to parse ; it is modified in place (transitions are
               applied to it and the EOS transition is applied at the end).
      dicts  : dictionaries passed to Features.extractFeatures.
      debug  : when True, print the configuration and the scored candidates on
               stderr at every step.

      Returns a list of 1-D tensors ; the first element of each tensor is the
      index (in ts) of the transition chosen by the oracle, the remaining
      elements are the features returned by Features.extractFeatures.
      """
      # torch is used below but is not imported at module level in this file,
      # so bring it into scope here.
      import torch

      examples = []
      # Names in the same order as ts ; computed once instead of at every step.
      names = [trans.name for trans in ts]

      EOS = Transition("EOS")
      config.moveWordIndex(0)
      moved = True
      while moved :
        missingLinks = getMissingLinks(config)
        # [score, name] pairs sorted in ascending order : the first one is the
        # candidate with the smallest oracle score (ties broken by name).
        candidates = sorted([trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config))
        if not candidates :
          break
        candidate = candidates[0][1]
        candidateIndex = names.index(candidate)
        features = Features.extractFeatures(dicts, config)
        example = torch.cat([torch.LongTensor([candidateIndex]), features])
        examples.append(example)
        if debug :
          config.printForDebug(sys.stderr)
          print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
        # The loop stops as soon as applyTransition returns a false value.
        moved = applyTransition(ts, strat, config, candidate)

      EOS.apply(config)

      return examples
    ################################################################################
    
    
    ################################################################################
    def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
      """Train a Networks.BaseNet classifier on examples produced by the oracle.

      debug         : forwarded to extractExamples and Decode.decodeMode.
      modelDir      : directory where dicts.json, network.pt and
                      predicted_dev.conllu are written.
      filename      : CoNLL-U training file, used to build the dictionaries.
      nbIter        : number of epochs.
      batchSize     : number of examples per batch.
      devFile       : CoNLL-U development file, or None. When given, the model
                      is evaluated on it after every epoch and is saved when
                      the UAS improves ; otherwise it is saved when the epoch
                      loss improves.
      transitionSet : list of transitions (one output class per transition).
      strategy      : forwarded to extractExamples.
      sentences     : configurations to extract the examples from.
      silent        : when True, do not print the progress inside an epoch.
      """
      # torch is used below but is not imported at module level in this file,
      # so bring it into scope here.
      import torch

      examples = []
      dicts = Dicts()
      dicts.readConllu(filename, ["FORM", "UPOS"])
      dicts.save(modelDir+"/dicts.json")
      print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
      for config in sentences :
        examples.extend(extractExamples(transitionSet, strategy, config, dicts, debug))
      print("%s : Extracted %d examples"%(timeStamp(), len(examples)), file=sys.stderr)
      # One row per example : column 0 is the gold transition index, the other
      # columns are the features (see extractExamples).
      examples = torch.stack(examples)

      network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
      optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
      lossFct = torch.nn.CrossEntropyLoss()

      bestLoss = None
      bestScore = None
      printInterval = 2000

      for epoch in range(1,nbIter+1) :

        network.train()

        # Shuffle the examples at the start of every epoch.
        examples = examples.index_select(0, torch.randperm(examples.size(0)))
        nbExamples = examples.size(0)
        totalLoss = 0.0
        nbEx = 0
        advancement = 0
        # Iterate up to nbExamples so that the last batch (possibly smaller than
        # batchSize) is used too. Stopping at nbExamples-batchSize skipped it,
        # and trained on nothing at all when nbExamples <= batchSize.
        for batchIndex in range(0,nbExamples,batchSize) :
          batch = examples[batchIndex:batchIndex+batchSize]
          targets = batch[:,:1].view(-1)
          inputs = batch[:,1:]
          nbEx += targets.size(0)
          advancement += targets.size(0)
          if not silent and advancement >= printInterval :
            advancement = 0
            print("Curent epoch %6.2f%%"%(100.0*nbEx/nbExamples), end="\r", file=sys.stderr)
          outputs = network(inputs)
          loss = lossFct(outputs, targets)
          network.zero_grad()
          loss.backward()
          optimizer.step()
          totalLoss += float(loss)
        devScore = ""

        # Without a dev file, the model is saved when the epoch loss improves.
        saved = bestLoss is None or totalLoss < bestLoss
        bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)

        if devFile is not None :
          outFilename = modelDir+"/predicted_dev.conllu"

          # Close the prediction file before reading it back, so that everything
          # written by the decoder is flushed to disk.
          # CoNLL-U files are UTF-8, hence the explicit encoding.
          with open(outFilename, "w", encoding="utf-8") as outFile :
            Decode.decodeMode(debug, devFile, "model", modelDir, network, dicts, outFile)

          with open(devFile, "r", encoding="utf-8") as goldFile :
            gold = load_conllu(goldFile)
          with open(outFilename, "r", encoding="utf-8") as predFile :
            predicted = load_conllu(predFile)
          res = evaluate(gold, predicted, [])

          UAS = res["UAS"][0].f1
          score = UAS
          # With a dev file, the dev score decides whether the model is saved.
          saved = bestScore is None or score > bestScore
          bestScore = score if bestScore is None else max(bestScore, score)
          devScore = ", Dev : UAS=%.2f"%(UAS)
        if saved :
          torch.save(network, modelDir+"/network.pt")
        print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), epoch, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
    
    ################################################################################