Skip to content
Snippets Groups Projects
Decode.py 4.36 KiB
Newer Older
  • Learn to ignore specific revisions
  • import random
    import sys
    from Transition import Transition, getMissingLinks, applyTransition
    
    from Dicts import Dicts
    
    import torch
    
    ################################################################################
    def randomDecode(ts, strat, config, debug=False) :
      """Decode `config` by repeatedly applying a uniformly random appliable transition.

      Starting from word index 0, picks one transition among those of `ts` whose
      appliable(config) is true and applies it with a reward of 0., until no
      transition is appliable; the "EOS" transition is then applied.

      Args:
        ts: collection of transitions (objects exposing appliable(config) and
          name); it is re-iterated on every step.
        strat: strategy forwarded to applyTransition and EOS.apply.
        config: the sentence configuration, modified in place.
        debug: when True, dumps the configuration and the chosen transition to stderr.
      """
      EOS = Transition("EOS")
      config.moveWordIndex(0)
      while True :
        candidates = [trans for trans in ts if trans.appliable(config)]
        if not candidates :
          break
        # random.choice is uniform. The former `randint(0, 100) % len(candidates)`
        # drew from 101 values, so low indices were favoured whenever
        # len(candidates) did not divide 101.
        candidate = random.choice(candidates)
        if debug :
          config.printForDebug(sys.stderr)
          print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)

        # NOTE(review): unlike the other decoders, the return value of
        # applyTransition ("moved") is ignored here; termination relies only on
        # no transition being appliable -- confirm this cannot loop forever.
        applyTransition(strat, config, candidate, 0.)

      EOS.apply(config, strat)
    
    ################################################################################
    
    ################################################################################
    def oracleDecode(ts, strat, config, debug=False) :
      """Decode `config` by always following the lowest-cost oracle transition.

      At each step, every appliable transition of `ts` is scored with
      getOracleScore(config, missingLinks). The [score, transition] pairs are
      sorted ascending and the first one is applied with a reward of 0.
      Decoding stops when nothing is appliable or when applyTransition returns a
      falsy value; the "EOS" transition is then applied.

      Args:
        ts: collection of transitions, re-iterated on every step.
        strat: strategy forwarded to applyTransition and EOS.apply.
        config: the sentence configuration, modified in place.
        debug: when True, dumps the configuration and all scored candidates to stderr.
      """
      endOfSentence = Transition("EOS")
      config.moveWordIndex(0)
      while True :
        links = getMissingLinks(config)

        # NOTE: pairs are compared as lists, so equal scores fall back to
        # comparing the transitions themselves (they must be orderable).
        ranked = sorted([trans.getOracleScore(config, links), trans] for trans in ts if trans.appliable(config))
        if not ranked :
          break
        best = ranked[0][1]

        if debug :
          config.printForDebug(sys.stderr)
          summary = " | ".join("%d '%s'"%(score, str(trans)) for score, trans in ranked)
          print(summary+"\n"+("-"*80)+"\n", file=sys.stderr)

        if not applyTransition(strat, config, best, 0.) :
          break

      endOfSentence.apply(config, strat)
    
    ################################################################################
    
    ################################################################################
    
    def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) :
      """Decode `config` greedily with `network`, taking the best-scored appliable transition.

      The network is put in eval mode and moved to the device returned by
      getDevice() for the duration of the call. Afterwards it is moved back to
      the device it was on (network.currentDevice()), even if decoding raises.
      Each applied transition receives the reward computed by rewarding(...).

      Args:
        ts: indexable sequence of transitions; ts[i] is scored by column i of the
          network output (assumed from the indexing below -- TODO confirm).
        strat: strategy forwarded to applyTransition and EOS.apply.
        config: the sentence configuration, modified in place.
        network: model exposing extractFeatures(dicts, config), currentDevice()
          and a forward pass.
        dicts: forwarded to network.extractFeatures.
        debug: when truthy, prints scores, the chosen action and oracle costs to stderr.
        rewardFunc: forwarded to rewarding() for each applied transition.
      """
      EOS = Transition("EOS")
      config.moveWordIndex(0)
      # NOTE(review): the network is left in eval mode; a caller that resumes
      # training must call network.train() itself.
      network.eval()

      currentDevice = network.currentDevice()
      decodeDevice = getDevice()
      network.to(decodeDevice)

      if debug :
        print("\n"+("-"*80), file=sys.stderr)

      try :
        with torch.no_grad():
          moved = True
          while moved :
            features = network.extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
            output = network(features)

            # Bring the score row to the CPU once. Calling float() on each
            # element of a GPU tensor would force one device sync per transition.
            rowScores = output[0].cpu()

            # Entries are [score, isAppliable, transition], sorted best-first.
            scores = sorted([[float(rowScores[index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1]
            candidates = [[cand[0],cand[2]] for cand in scores if cand[1]]
            if not candidates :
              break
            candidate = candidates[0][1]

            missingLinks = getMissingLinks(config)

            if debug :
              config.printForDebug(sys.stderr)
              print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%str(candidate), file=sys.stderr)
              # Oracle costs for comparison; BACK transitions are excluded from this listing.
              oracleCandidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and "BACK" not in trans.name])
              print("Oracle costs :"+str([[c[0],str(c[1])] for c in oracleCandidates]), file=sys.stderr)

            reward = rewarding(True, config, candidate, missingLinks, rewardFunc)
            moved = applyTransition(strat, config, candidate, reward)
      finally :
        # Hand the network back on the device it came from (previously the saved
        # currentDevice was never used, leaving the model on decodeDevice).
        network.to(currentDevice)

      EOS.apply(config, strat)
    
    ################################################################################
    
    
    ################################################################################
    
    def decodeMode(debug, filename, type, transitionSet, strategy, rewardFunc, predicted, modelDir=None, network=None, dicts=None, output=sys.stdout) :
      """Read `filename` with Config.readConllu, decode every sentence and print them.

      Args:
        debug: forwarded to the per-sentence decoder.
        filename: file passed to Config.readConllu(filename, predicted).
        type: "random", "oracle" or "model"; selects the decoder.
        transitionSet: transitions forwarded to the decoder.
        strategy: strategy forwarded to the decoder.
        rewardFunc: reward function used by the "model" decoder.
        predicted: forwarded to Config.readConllu.
        modelDir: directory holding dicts.json and network.pt. It is only read
          for type "model" when `dicts` is None.
        network: already loaded model to decode with (type "model").
        dicts: already loaded Dicts matching `network` (type "model").
        output: stream that receives the decoded sentences, for every type.

      Prints an error and exits the process with status 1 on an unknown `type`.
      """
      sentences = Config.readConllu(filename, predicted)

      if type in ["random", "oracle"] :
        decodeFunc = oracleDecode if type == "oracle" else randomDecode
        for config in sentences :
          decodeFunc(transitionSet, strategy, config, debug)
      elif type == "model" :
        if dicts is None :
          dicts = Dicts()
          dicts.load(modelDir+"/dicts.json")
          # SECURITY NOTE: torch.load unpickles the file, so only load models
          # from trusted directories.
          # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which
          # rejects full pickled modules like this one -- confirm the torch version.
          network = torch.load(modelDir+"/network.pt")
        for config in sentences :
          decodeModel(transitionSet, strategy, config, network, dicts, debug, rewardFunc)
      else :
        print("ERROR : unknown type '%s'"%type, file=sys.stderr)
        sys.exit(1)

      # Every decoder type writes to `output`; random/oracle decoding used to
      # write to sys.stdout unconditionally. Only the first sentence carries the
      # header, and an empty input prints nothing instead of raising IndexError.
      for index, config in enumerate(sentences) :
        config.print(output, header=(index == 0))
    ################################################################################