Skip to content
Snippets Groups Projects
Decode.py 2.45 KiB
Newer Older
  • Learn to ignore specific revisions
  • import random
    import sys
    from Transition import Transition, getMissingLinks, applyTransition
    from Features import extractFeatures
    import torch
    
    ################################################################################
    def randomDecode(ts, strat, config, debug=False) :
      """Decode `config` by repeatedly applying a uniformly random appliable transition.

      Resets the word index to 0, then loops: collect every transition of `ts`
      whose `appliable(config)` is true, pick one uniformly at random and apply it
      through `applyTransition`. The loop ends when no transition is appliable,
      after which an "EOS" transition is applied to `config`.

      Args:
        ts: iterable of transitions; each must provide `appliable(config)` and `name`.
        strat: forwarded unchanged to `applyTransition`.
        config: configuration being decoded (presumably mutated in place by
          `moveWordIndex` / `applyTransition` — confirm against Config).
        debug: when true, print the configuration and the chosen transition name
          to stderr before each step.
      """
      EOS = Transition("EOS")
      config.moveWordIndex(0)
      while True :
        candidates = [trans for trans in ts if trans.appliable(config)]
        if not candidates :
          break
        # random.choice draws uniformly; `randint(0, 100) % len(candidates)` is biased
        # towards the first candidates whenever len(candidates) does not divide 101.
        candidate = random.choice(candidates)
        if debug :
          config.printForDebug(sys.stderr)
          print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)
        # NOTE(review): unlike the other decoders, the return value of applyTransition
        # is ignored here; termination relies on transitions eventually becoming
        # non-appliable.
        applyTransition(ts, strat, config, candidate.name)

      EOS.apply(config)
    ################################################################################
    
    ################################################################################
    def oracleDecode(ts, strat, config, debug=False) :
      """Decode `config` by always following the transition with the lowest oracle score.

      After resetting the word index to 0, each step recomputes the missing links
      of the configuration, scores every appliable transition of `ts` with
      `getOracleScore(config, missingLinks)` and applies the one whose
      `[score, name]` pair sorts first (lowest score, ties broken by name).
      Decoding stops when nothing is appliable or when `applyTransition` returns
      a falsy value; an "EOS" transition is then applied.

      Args:
        ts: iterable of transitions; each must provide `appliable(config)`,
          `getOracleScore(config, missingLinks)` and `name`.
        strat: forwarded unchanged to `applyTransition`.
        config: configuration being decoded.
        debug: when true, print the configuration and the sorted `[score, name]`
          candidates to stderr before each step.
      """
      eosTransition = Transition("EOS")
      config.moveWordIndex(0)
      keepGoing = True
      while keepGoing :
        links = getMissingLinks(config)
        scored = []
        for transition in ts :
          if not transition.appliable(config) :
            continue
          scored.append([transition.getOracleScore(config, links), transition.name])
        if not scored :
          break
        scored.sort()
        best = scored[0][1]
        if debug :
          config.printForDebug(sys.stderr)
          print(str(scored)+"\n"+("-"*80)+"\n", file=sys.stderr)
        keepGoing = applyTransition(ts, strat, config, best)

      eosTransition.apply(config)
    ################################################################################
    
    ################################################################################
    def decodeModel(ts, strat, config, network, dicts, debug=False) :
      """Greedily decode `config` using the probabilities predicted by `network`.

      After resetting the word index to 0, each step extracts features from the
      current configuration, applies a softmax to the network output, and applies
      the appliable transition with the highest probability. Exact ties are broken
      by descending transition name. Decoding stops when no transition is appliable
      or when `applyTransition` returns a falsy value; an "EOS" transition is then
      applied.

      Args:
        ts: indexable sequence of transitions aligned with the network's output
          columns (column `i` of the output scores `ts[i]`).
        strat: forwarded unchanged to `applyTransition`.
        config: configuration being decoded.
        network: torch module; it is switched to eval mode and left in it.
        dicts: forwarded to `extractFeatures`.
        debug: when true, print the configuration and the ranked candidate names
          to stderr before each step. Defaults to False, like the other decoders.
      """
      EOS = Transition("EOS")
      config.moveWordIndex(0)
      moved = True
      network.eval()
      with torch.no_grad():
        while moved :
          features = extractFeatures(dicts, config).unsqueeze(0)
          output = torch.nn.functional.softmax(network(features), dim=1)
          probabilities = output[0].tolist()
          # Rank on the exact probabilities. Rounding them to "%.2f" strings first made
          # close scores tie and fall back to name order, which could select a
          # lower-probability transition.
          ranked = sorted(
            ((probabilities[index], trans.name) for index, trans in enumerate(ts) if trans.appliable(config)),
            reverse=True)
          candidates = [name for _, name in ranked]
          if not candidates :
            break
          candidate = candidates[0]
          if debug :
            config.printForDebug(sys.stderr)
            print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
          moved = applyTransition(ts, strat, config, candidate)

      EOS.apply(config)
    ################################################################################