Skip to content
Snippets Groups Projects
Decode.py 2.45 KiB
Newer Older
import random
import sys
from Transition import Transition, getMissingLinks, applyTransition
from Features import extractFeatures
import torch

################################################################################
def randomDecode(ts, strat, config, debug=False) :
  """Decode `config` by repeatedly applying a randomly chosen appliable transition.

  Moves the word index to 0, then loops. Each iteration collects every
  transition in `ts` whose `appliable(config)` is true, picks one uniformly at
  random, and applies it via `applyTransition`. The loop stops once no
  transition is appliable. Finally a `Transition("EOS")` is applied to
  `config`.

  Args:
    ts: iterable of candidate transitions.
    strat: strategy object forwarded unchanged to `applyTransition`.
    config: the configuration being decoded; mutated in place.
    debug: when true, dump the configuration and the chosen transition name
      to stderr before each application.
  """
  EOS = Transition("EOS")
  config.moveWordIndex(0)
  while True :
    candidates = [trans for trans in ts if trans.appliable(config)]
    if not candidates :
      break
    # random.choice is uniform; the former `randint(0, 100) % len(candidates)`
    # over-weighted low-index transitions whenever 101 was not a multiple of
    # the number of candidates.
    candidate = random.choice(candidates)
    if debug :
      config.printForDebug(sys.stderr)
      print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)
    # NOTE(review): the return value of applyTransition is ignored here (unlike
    # oracleDecode/decodeModel); termination relies on eventually having no
    # appliable transition — confirm this always happens.
    applyTransition(ts, strat, config, candidate.name)

  EOS.apply(config)
################################################################################

################################################################################
def oracleDecode(ts, strat, config, debug=False) :
  """Decode `config` by always following the best-scoring oracle transition.

  Moves the word index to 0, then loops. Each iteration recomputes the
  missing links with `getMissingLinks(config)`. It scores every appliable
  transition with `getOracleScore(config, missingLinks)` and applies the one
  that sorts first among the `[score, name]` pairs, i.e. the lowest score,
  with ties broken by name. Presumably a lower oracle score means fewer lost
  links — confirm against Transition.getOracleScore.

  The loop ends when no transition is appliable or when `applyTransition`
  returns a falsy value. A `Transition("EOS")` is then applied to `config`.

  Args:
    ts: iterable of candidate transitions.
    strat: strategy object forwarded unchanged to `applyTransition`.
    config: the configuration being decoded; mutated in place.
    debug: when true, dump the configuration and the full sorted
      `[score, name]` list to stderr before each application.
  """
  EOS = Transition("EOS")
  config.moveWordIndex(0)
  while True :
    missing = getMissingLinks(config)
    ranked = [[trans.getOracleScore(config, missing), trans.name]
              for trans in ts if trans.appliable(config)]
    if not ranked :
      break
    ranked.sort()
    best_name = ranked[0][1]
    if debug :
      config.printForDebug(sys.stderr)
      print(str(ranked)+"\n"+("-"*80)+"\n", file=sys.stderr)
    # Stop as soon as applying the chosen transition reports no progress.
    if not applyTransition(ts, strat, config, best_name) :
      break

  EOS.apply(config)
################################################################################

################################################################################
def decodeModel(ts, strat, config, network, dicts, debug) :
  """Decode `config` greedily using the predictions of `network`.

  Switches `network` to eval mode (it is left in eval mode afterwards) and
  runs inside `torch.no_grad()`. Each iteration does the following:
  - Extracts features with `extractFeatures(dicts, config)` and adds a batch
    dimension.
  - Applies a softmax over dim 1 of the network output.
  - Picks the appliable transition with the highest probability, where
    column `i` of the output row is taken as the score of `ts[i]`.
  - Applies that transition.

  The loop ends when no transition is appliable or when `applyTransition`
  returns a falsy value. A `Transition("EOS")` is then applied to `config`.

  Args:
    ts: indexable sequence of transitions, aligned with the network's output
      columns.
    strat: strategy object forwarded unchanged to `applyTransition`.
    config: the configuration being decoded; mutated in place.
    network: torch module mapping a feature batch to one score per transition.
    dicts: feature dictionaries forwarded to `extractFeatures`.
    debug: when true, dump the configuration and the ranked appliable
      `[probability, name]` pairs to stderr before each application.
  """
  EOS = Transition("EOS")
  config.moveWordIndex(0)
  moved = True
  network.eval()
  with torch.no_grad():
    while moved :
      features = extractFeatures(dicts, config).unsqueeze(0)
      output = torch.nn.functional.softmax(network(features), dim=1)
      # Convert the probability row to Python floats once instead of once per
      # transition.
      scores = output[0].tolist()
      # Rank by the raw probability. The former "%.2f" string key turned
      # nearby probabilities into ties that were then broken by transition
      # name, so the chosen transition was not always the most probable one.
      # Exact ties still fall back to descending name order, as before.
      ranked = sorted(((scores[index], trans.name)
                       for index, trans in enumerate(ts) if trans.appliable(config)),
                      reverse=True)
      if not ranked :
        break
      candidate = ranked[0][1]
      if debug :
        config.printForDebug(sys.stderr)
        print(str([["%.2f"%score, name] for score, name in ranked])+"\n"+("-"*80)+"\n", file=sys.stderr)
      moved = applyTransition(ts, strat, config, candidate)

  EOS.apply(config)
################################################################################