#! /usr/bin/env python3

import sys
import random
import Config
from Transition import Transition, getMissingLinks

################################################################################
def applyTransition(ts, strat, config, name) :
  """Apply the transition called `name` to `config`, then move the word index.

  ts     : iterable of transitions; the first one whose `.name` equals `name`
           is applied.
  strat  : dict mapping a transition name to the word-index movement performed
           after that transition (e.g. {"SHIFT" : 1, "LEFT" : 0}).
  config : configuration handed to `transition.apply` and `moveWordIndex`.
  name   : name of the transition to apply.

  Returns the value of `config.moveWordIndex(movement)`. The decoders use it
  as their "keep decoding" flag. NOTE(review): presumably this is False once
  the index can no longer advance — confirm in Config.moveWordIndex.

  Raises ValueError if no transition in `ts` is called `name`, and KeyError if
  `strat` has no entry for it.
  """
  transition = next((trans for trans in ts if trans.name == name), None)
  if transition is None :
    raise ValueError("Unknown transition : "+str(name))
  # Look the movement up before mutating config, so a missing strategy entry
  # fails without leaving config half-updated.
  movement = strat[transition.name]
  transition.apply(config)
  # The decoding loops stop when this is falsy. Returning nothing (as before)
  # made every decode stop after a single transition.
  return config.moveWordIndex(movement)
################################################################################

################################################################################
def randomDecode(ts, strat, config) :
  """Decode `config` by repeatedly applying a random appliable transition.

  ts     : candidate transitions; at each step one of those whose
           `appliable(config)` is true is picked uniformly at random.
  strat  : name -> word-index movement, forwarded to applyTransition.
  config : configuration updated in place.

  Decoding stops when no transition is appliable, or when applyTransition
  returns a falsy value. An EOS transition is then applied.
  Debug traces are written to stderr when the module-level `args.debug` is
  set. `args` is only defined when the file runs as a script.
  """
  EOS = Transition("EOS")
  config.moveWordIndex(0)
  moved = True
  while moved :
    candidates = [trans for trans in ts if trans.appliable(config)]
    # Without this guard, picking from an empty list would crash.
    if not candidates :
      break
    # random.choice is uniform; `randint(0, 100) % len` was biased.
    candidate = random.choice(candidates)
    if args.debug :
      config.printForDebug(sys.stderr)
      print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)
    moved = applyTransition(ts, strat, config, candidate.name)

  EOS.apply(config)
################################################################################

################################################################################
def oracleDecode(ts, strat, config) :
  """Decode `config` greedily using the transitions' oracle scores.

  At each step, every transition in `ts` that is appliable to `config` is
  scored with `getOracleScore(config, missingLinks)`. The one with the lowest
  score is applied; ties are broken by transition name.

  ts     : candidate transitions.
  strat  : name -> word-index movement, forwarded to applyTransition.
  config : configuration updated in place.

  Decoding stops when no transition is appliable, or when applyTransition
  returns a falsy value. An EOS transition is then applied.
  Debug traces are written to stderr when the module-level `args.debug` is
  set. `args` is only defined when the file runs as a script.
  """
  EOS = Transition("EOS")
  config.moveWordIndex(0)
  moved = True
  while moved :
    missingLinks = getMissingLinks(config)
    # [score, name] pairs: ascending sort puts the lowest score first.
    candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config)])
    if not candidates :
      break
    candidate = candidates[0][1]
    if args.debug :
      config.printForDebug(sys.stderr)
      print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
    moved = applyTransition(ts, strat, config, candidate)

  EOS.apply(config)
################################################################################

################################################################################
if __name__ == "__main__" :
  # argparse was used here without ever being imported.
  import argparse

  parser = argparse.ArgumentParser()
  parser.add_argument("trainCorpus", type=str,
    help="Name of the CoNLL-U training file.")
  parser.add_argument("--debug", "-d", default=False, action="store_true",
    help="Print debug infos on stderr.")
  # Module-level global: the decoders read `args.debug`.
  args = parser.parse_args()

  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
  # Word-index movement performed after each transition (see applyTransition).
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

  # Use the parsed positional argument rather than sys.argv[1], which is
  # wrong when an option such as -d comes first.
  sentences = Config.readConllu(args.trainCorpus)

  # Only the first printed sentence gets the header.
  first = True
  for config in sentences :
    oracleDecode(transitionSet, strategy, config)
    config.print(sys.stdout, header=first)
    first = False
################################################################################