#! /usr/bin/env python3 import sys import random import argparse import Config from Transition import Transition, getMissingLinks ################################################################################ def applyTransition(ts, strat, config, name) : transition = [trans for trans in ts if trans.name == name][0] movement = strat[transition.name] transition.apply(config) return config.moveWordIndex(movement) ################################################################################ ################################################################################ def randomDecode(ts, strat, config) : EOS = Transition("EOS") config.moveWordIndex(0) while True : candidates = [trans for trans in transitionSet if trans.appliable(config)] if len(candidates) == 0 : break candidate = candidates[random.randint(0, 100) % len(candidates)] if args.debug : config.printForDebug(sys.stderr) print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr) applyTransition(transitionSet, strategy, config, candidate.name) EOS.apply(config) ################################################################################ ################################################################################ def oracleDecode(ts, strat, config) : EOS = Transition("EOS") config.moveWordIndex(0) moved = True while moved : missingLinks = getMissingLinks(config) candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in transitionSet if trans.appliable(config)]) if len(candidates) == 0 : break candidate = candidates[0][1] if args.debug : config.printForDebug(sys.stderr) print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr) moved = applyTransition(transitionSet, strategy, config, candidate) EOS.apply(config) ################################################################################ ################################################################################ if __name__ == "__main__" : parser = argparse.ArgumentParser() parser.add_argument("trainCorpus", type=str, help="Name of the CoNLL-U training file.") parser.add_argument("--debug", "-d", default=False, action="store_true", help="Print debug infos on stderr.") args = parser.parse_args() transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]] strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0} sentences = Config.readConllu(sys.argv[1]) first = True for config in sentences : oracleDecode(transitionSet, strategy, config) config.print(sys.stdout, header=first) first = False ################################################################################