Skip to content
Snippets Groups Projects
Decode.py 3.7 KiB
Newer Older
import random
import sys
from Transition import Transition, getMissingLinks, applyTransition
from Features import extractFeatures
from Dicts import Dicts
import torch

################################################################################
def randomDecode(ts, strat, config, debug=False) :
  """Decode `config` by repeatedly applying a randomly chosen appliable transition.

  Moves the word index to 0, then loops: collects every transition of `ts`
  for which `appliable(config)` is true, picks one uniformly at random and
  applies it with `applyTransition(ts, strat, config, name)`. The loop stops
  when no transition is appliable; the EOS transition is then applied.

  ts     -- list of Transition objects to choose from.
  strat  -- strategy forwarded unchanged to applyTransition.
  config -- configuration, modified in place.
  debug  -- when true, dump the configuration and chosen transition to stderr.
  """
  EOS = Transition("EOS")
  config.moveWordIndex(0)
  # NOTE(review): unlike oracleDecode/decodeModel, the return value of
  # applyTransition is ignored here, so termination relies on the set of
  # appliable transitions eventually becoming empty.
  while True :
    candidates = [trans for trans in ts if trans.appliable(config)]
    if not candidates :
      break
    # Uniform choice; `randint(0, 100) % len(candidates)` was biased toward
    # low indices whenever 101 is not a multiple of len(candidates).
    candidate = random.choice(candidates)
    if debug :
      config.printForDebug(sys.stderr)
      print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)
    applyTransition(ts, strat, config, candidate.name)

  EOS.apply(config)
################################################################################

################################################################################
def oracleDecode(ts, strat, config, debug=False) :
  """Decode `config` by always following the best-scored appliable transition.

  Starting at word index 0, each step computes the missing links of the
  configuration, scores every appliable transition with `getOracleScore`,
  and applies the one that sorts first as a [score, name] pair (lowest score,
  ties broken by name). Decoding stops when nothing is appliable or when
  `applyTransition` returns a falsy value; EOS is applied at the end.

  ts     -- list of Transition objects.
  strat  -- strategy forwarded unchanged to applyTransition.
  config -- configuration, modified in place.
  debug  -- when true, dump the configuration and the ranked candidates to stderr.
  """
  eosTransition = Transition("EOS")
  config.moveWordIndex(0)
  while True :
    missing = getMissingLinks(config)
    ranked = []
    for trans in ts :
      if trans.appliable(config) :
        ranked.append([trans.getOracleScore(config, missing), trans.name])
    if not ranked :
      break
    ranked.sort()
    bestName = ranked[0][1]
    if debug :
      config.printForDebug(sys.stderr)
      print(str(ranked)+"\n"+("-"*80)+"\n", file=sys.stderr)
    # applyTransition's result tells whether decoding can continue.
    if not applyTransition(ts, strat, config, bestName) :
      break

  eosTransition.apply(config)
################################################################################

################################################################################
def decodeModel(ts, strat, config, network, dicts, debug) :
  """Decode `config` greedily with a trained network.

  Puts `network` in eval mode and, without tracking gradients, repeatedly:
  extracts features from `config` via `extractFeatures(dicts, config)`, runs
  the network on a batch of one, applies softmax over dim 1, and applies the
  appliable transition with the highest probability. Decoding stops when no
  transition is appliable or `applyTransition` returns a falsy value; EOS is
  applied at the end.

  ts      -- list of Transition objects; presumably index i matches output
             column i of the network (TODO confirm against training code).
  strat   -- strategy forwarded unchanged to applyTransition.
  config  -- configuration, modified in place.
  network -- torch module producing one score per transition.
  dicts   -- Dicts instance used by extractFeatures.
  debug   -- when true, dump the configuration and ranked candidates to stderr.
  """
  EOS = Transition("EOS")
  config.moveWordIndex(0)
  moved = True
  network.eval()
  with torch.no_grad():
    while moved :
      features = extractFeatures(dicts, config).unsqueeze(0)
      probs = torch.nn.functional.softmax(network(features), dim=1)
      # Rank on the exact probability. Ranking on the "%.2f" string made
      # near-ties (e.g. 0.954 vs 0.951) fall back to ordering by name.
      # Exact ties are still broken by name, descending, as before.
      scored = [(float(probs[0][index]), trans.name) for index, trans in enumerate(ts) if trans.appliable(config)]
      if not scored :
        break
      scored.sort(reverse=True)
      candidate = scored[0][1]
      if debug :
        config.printForDebug(sys.stderr)
        print(str([["%.2f"%score, name] for score, name in scored])+"\n"+("-"*80)+"\n", file=sys.stderr)
      moved = applyTransition(ts, strat, config, candidate)

  EOS.apply(config)
################################################################################

################################################################################
def decodeMode(debug, filename, type, modelDir = None, network=None, dicts=None, output=sys.stdout) :
  """Decode every sentence read from `filename` and print the result to `output`.

  debug    -- forwarded to the per-sentence decoding function.
  filename -- file parsed by Config.readConllu.
  type     -- "random", "oracle" or "model"; any other value prints an error
              to stderr and exits with status 1.
  modelDir -- directory holding dicts.json and network.pt; only read in
              "model" mode for whichever of `dicts` / `network` is not given.
  network  -- already-loaded network (model mode); loaded from modelDir if None.
  dicts    -- already-loaded Dicts (model mode); loaded from modelDir if None.
  output   -- stream the decoded sentences are printed to (all modes). The
              first sentence is printed with a header, the others without.
  """
  # Config is used below but was never imported at module level.
  import Config

  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

  sentences = Config.readConllu(filename)

  if type in ["random", "oracle"] :
    decodeFunc = oracleDecode if type == "oracle" else randomDecode
    for config in sentences :
      decodeFunc(transitionSet, strategy, config, debug)
  elif type == "model" :
    # Load each missing piece independently. Previously a caller passing
    # `dicts` but no `network` crashed on None.eval().
    if dicts is None :
      dicts = Dicts()
      dicts.load(modelDir+"/dicts.json")
    if network is None :
      # torch.load unpickles the file: only load trusted model directories.
      # NOTE(review): recent PyTorch versions default to weights_only=True,
      # which may reject a fully pickled module. Confirm for the torch
      # version in use.
      network = torch.load(modelDir+"/network.pt")
    for config in sentences :
      decodeModel(transitionSet, strategy, config, network, dicts, debug)
  else :
    print("ERROR : unknown type '%s'"%type, file=sys.stderr)
    sys.exit(1)

  # Every mode honours `output`; an empty input prints nothing instead of
  # raising IndexError on sentences[0].
  for index, config in enumerate(sentences) :
    config.print(output, header=(index == 0))
################################################################################