# Decode.py : decoding of sentences with a transition system
# (random, oracle-driven or model-driven choice of transitions).
import random
import sys
from Transition import Transition, getMissingLinks, applyTransition
from Dicts import Dicts
import torch

################################################################################
def randomDecode(ts, strat, config, debug=False) :
  """Decode `config` by applying uniformly random appliable transitions.

  Starting from word index 0, each step picks, uniformly at random, one
  transition of `ts` whose `appliable(config)` is true and applies it through
  `applyTransition`. Decoding stops when no transition is appliable; the
  "EOS" transition is then applied.

  Parameters:
    ts     : iterable of Transition objects (the transition set).
    strat  : strategy forwarded to `applyTransition`.
    config : the configuration being decoded (moved and transformed through
             moveWordIndex, applyTransition and EOS.apply).
    debug  : if True, dump the configuration and the chosen transition name
             to stderr at each step.
  """
  EOS = Transition("EOS")
  config.moveWordIndex(0)
  while True :
    candidates = [trans for trans in ts if trans.appliable(config)]
    if not candidates :
      break
    # random.choice is uniform over every candidate, whereas
    # randint(0, 100) % len(candidates) is biased and never reaches index > 100.
    candidate = random.choice(candidates)
    if debug :
      config.printForDebug(sys.stderr)
      print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)
    applyTransition(ts, strat, config, candidate.name, 0.)

  # NOTE(review): decodeModel calls EOS.apply(config, strat) while this calls
  # EOS.apply(config) -- confirm which signature Transition.apply expects.
  EOS.apply(config)
################################################################################

################################################################################
def oracleDecode(ts, strat, config, debug=False) :
  """Decode `config` by following the oracle.

  Starting from word index 0, each step computes the missing links of the
  configuration with `getMissingLinks`. It then scores every appliable
  transition of `ts` with `getOracleScore(config, missingLinks)` and applies
  the one with the lowest score (ties broken by transition name). Decoding
  stops when no transition is appliable or when `applyTransition` returns a
  falsy value; the "EOS" transition is then applied.

  Parameters:
    ts     : iterable of Transition objects (the transition set).
    strat  : strategy forwarded to `applyTransition`.
    config : the configuration being decoded (moved and transformed through
             moveWordIndex, applyTransition and EOS.apply).
    debug  : if True, dump the configuration and the ranked
             [score, name] candidates to stderr at each step.
  """
  EOS = Transition("EOS")
  config.moveWordIndex(0)
  moved = True
  while moved :
    missingLinks = getMissingLinks(config)
    # [oracle score, name] pairs of appliable transitions, lowest score first.
    candidates = sorted([trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config))
    if not candidates :
      break
    candidate = candidates[0][1]
    if debug :
      config.printForDebug(sys.stderr)
      print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
    moved = applyTransition(ts, strat, config, candidate, 0.)

  # NOTE(review): decodeModel calls EOS.apply(config, strat) while this calls
  # EOS.apply(config) -- confirm which signature Transition.apply expects.
  EOS.apply(config)
################################################################################

################################################################################
def decodeModel(ts, strat, config, network, dicts, debug) :
  """Decode `config` greedily, following the scores predicted by `network`.

  Starting from word index 0, each step extracts features for the current
  configuration and runs `network` on them without gradient tracking. It then
  applies the highest-scoring transition of `ts` that is appliable. Decoding
  stops when no transition is appliable or when `applyTransition` returns a
  falsy value; the "EOS" transition is then applied.

  `network` is switched to eval mode and moved to `getDevice()` for the
  decoding. Afterwards it is moved back to the device that
  `network.currentDevice()` reported beforehand, including when decoding
  raises.

  Parameters:
    ts      : sequence of Transition objects; network scores are read
              index-aligned with it.
    strat   : strategy forwarded to `applyTransition` and `EOS.apply`.
    config  : the configuration being decoded (moved and transformed through
              moveWordIndex, applyTransition and EOS.apply).
    network : model providing `extractFeatures`, `currentDevice`, `eval`
              and `to`, and callable on the extracted features.
    dicts   : dictionaries forwarded to `network.extractFeatures`.
    debug   : if True, dump the configuration, every transition score
              (appliable ones marked with '*') and the chosen action to
              stderr at each step.
  """
  EOS = Transition("EOS")
  config.moveWordIndex(0)
  moved = True
  network.eval()

  # NOTE(review): getDevice is neither defined nor imported in this file;
  # it presumably lives in a project helper module -- confirm and import it.
  currentDevice = network.currentDevice()
  decodeDevice = getDevice()
  network.to(decodeDevice)

  try :
    if debug :
      print("\n"+("-"*80)+"\n", file=sys.stderr)

    with torch.no_grad():
      while moved :
        # unsqueeze(0) adds a leading dimension of size 1 (presumably a batch of one).
        features = network.extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
        prediction = network(features)
        # [score, appliable, name] for each transition, sorted in descending
        # order (ties broken by appliable, then by name).
        scores = sorted([[float(prediction[0][index]), trans.appliable(config), trans.name] for index, trans in enumerate(ts)])[::-1]
        appliableNames = [name for _, isAppliable, name in scores if isAppliable]
        if not appliableNames :
          break
        candidate = appliableNames[0]
        if debug :
          config.printForDebug(sys.stderr)
          print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%candidate+"\n"+("-"*80)+"\n", file=sys.stderr)
        moved = applyTransition(ts, strat, config, candidate, 0.)
    EOS.apply(config, strat)
  finally :
    # Put the network back on the device the caller had it on.
    network.to(currentDevice)
################################################################################

################################################################################
def _printSentences(sentences, output) :
  """Print every configuration of `sentences` to `output`, with the header
  only on the first one. An empty list prints nothing."""
  for index, config in enumerate(sentences) :
    config.print(output, header=(index == 0))

def decodeMode(debug, filename, type, transitionSet, strategy, modelDir=None, network=None, dicts=None, output=sys.stdout) :
  """Decode every sentence of a CoNLL-U file and print the result to `output`.

  Parameters:
    debug         : forwarded to the decoding function.
    filename      : file read with `Config.readConllu`.
    type          : "random", "oracle" or "model"; selects randomDecode,
                    oracleDecode or decodeModel.
    transitionSet : transition set forwarded to the decoding function.
    strategy      : strategy forwarded to the decoding function.
    modelDir      : directory holding dicts.json and network.pt, read for
                    type "model" when `dicts` (or `network`) is not given.
    network       : already-loaded network for type "model".
    dicts         : already-loaded Dicts for type "model"; when None, both
                    dicts and network are loaded from `modelDir`.
    output        : stream the decoded sentences are printed to (all types).

  On an unknown `type`, prints an error on stderr and exits with status 1
  before reading `filename`.
  """
  if type not in ["random", "oracle", "model"] :
    print("ERROR : unknown type '%s'"%type, file=sys.stderr)
    sys.exit(1)

  # NOTE(review): Config is neither defined nor imported in this file;
  # confirm which project module provides readConllu and import it.
  sentences = Config.readConllu(filename)

  if type == "model" :
    # torch.load unpickles the file: only load model directories you trust.
    # NOTE(review): recent torch versions default to weights_only=True, which
    # may refuse a fully pickled network -- confirm against the installed torch.
    if dicts is None :
      # dicts and network are reloaded together from modelDir.
      dicts = Dicts()
      dicts.load(modelDir+"/dicts.json")
      network = torch.load(modelDir+"/network.pt")
    elif network is None :
      network = torch.load(modelDir+"/network.pt")
    for config in sentences :
      decodeModel(transitionSet, strategy, config, network, dicts, debug)
  else :
    decodeFunc = oracleDecode if type == "oracle" else randomDecode
    for config in sentences :
      decodeFunc(transitionSet, strategy, config, debug)

  _printSentences(sentences, output)
################################################################################