Newer
Older
import argparse
Franck Dary
committed
from Transition import Transition, getMissingLinks
################################################################################
def applyTransition(ts, strat, config, name) :
  """Apply the transition called *name* to *config*, then move its word index.

  ts     -- iterable of transitions; the first one whose ``.name`` equals
            *name* is the one applied.
  strat  -- mapping from transition name to the movement passed to
            ``config.moveWordIndex`` after the transition is applied.
  config -- configuration object, modified in place.
  name   -- name of the transition to apply.

  Returns whatever ``config.moveWordIndex`` returns.
  Raises IndexError if no transition in *ts* is called *name*, and KeyError
  if *strat* has no entry for it.
  """
  matches = list(filter(lambda candidate : candidate.name == name, ts))
  chosen = matches[0]
  # Look up the movement before applying, so a missing strategy entry fails
  # without touching the configuration.
  shift = strat[chosen.name]
  chosen.apply(config)
  return config.moveWordIndex(shift)
################################################################################
################################################################################
def randomDecode(ts, strat, config, debug=None) :
  """Decode *config* by repeatedly applying a random appliable transition.

  Resets the word index with ``config.moveWordIndex(0)``. Then, while at
  least one transition of *ts* reports ``appliable(config)``, it picks one
  uniformly at random and applies it through ``applyTransition`` with the
  movement table *strat*. Finally it applies an "EOS" transition.

  ts     -- list of candidate transitions.
  strat  -- mapping from transition name to word-index movement.
  config -- configuration object, modified in place.
  debug  -- when true, print the configuration and the chosen transition
            name to stderr before each step. When None (the default), fall
            back to the script-level ``args.debug`` if the module is being
            run as a script, otherwise False.
  """
  if debug is None :
    # Keep the old behaviour when run as a script (the global `args` from
    # argparse), but do not crash with NameError when imported as a module.
    cliArgs = globals().get("args")
    debug = bool(getattr(cliArgs, "debug", False))

  EOS = Transition("EOS")
  config.moveWordIndex(0)
  while True :
    # Use the function's own arguments, not the script globals
    # transitionSet/strategy, so the decoder works for any transition set.
    candidates = [trans for trans in ts if trans.appliable(config)]
    if not candidates :
      break
    # Uniform choice; the previous randint(0, 100) % len(...) was biased
    # towards low indices whenever len did not divide 101.
    candidate = random.choice(candidates)
    if debug :
      config.printForDebug(sys.stderr)
      print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)
    applyTransition(ts, strat, config, candidate.name)
  EOS.apply(config)
################################################################################
################################################################################
def oracleDecode(ts, strat, config, debug=None) :
  """Decode *config* greedily by following the oracle scores.

  Resets the word index with ``config.moveWordIndex(0)``. Then, at each
  step:
  - ``getMissingLinks(config)`` is computed.
  - Every appliable transition of *ts* is scored with
    ``getOracleScore(config, missingLinks)``.
  - The candidate with the lowest score (ties broken by transition name) is
    applied through ``applyTransition`` with the movement table *strat*.

  The loop stops when no transition is appliable or when ``applyTransition``
  returns a falsy value. An "EOS" transition is then applied.

  ts     -- list of candidate transitions.
  strat  -- mapping from transition name to word-index movement.
  config -- configuration object, modified in place.
  debug  -- when true, print the configuration and the sorted
            [score, name] candidate list to stderr before each step. When
            None (the default), fall back to the script-level ``args.debug``
            if the module is being run as a script, otherwise False.
  """
  if debug is None :
    # Keep the old behaviour when run as a script (the global `args` from
    # argparse), but do not crash with NameError when imported as a module.
    cliArgs = globals().get("args")
    debug = bool(getattr(cliArgs, "debug", False))

  EOS = Transition("EOS")
  config.moveWordIndex(0)
  moved = True
  while moved :
    missingLinks = getMissingLinks(config)
    # Use the function's own arguments, not the script globals
    # transitionSet/strategy. Sorting [score, name] pairs orders by score
    # first, then by name for ties.
    candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config)])
    if not candidates :
      break
    candidate = candidates[0][1]
    if debug :
      config.printForDebug(sys.stderr)
      print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
    moved = applyTransition(ts, strat, config, candidate)
  EOS.apply(config)
################################################################################
################################################################################
if __name__ == "__main__" :
  # Command line: a CoNLL-U corpus path and an optional --debug flag.
  parser = argparse.ArgumentParser()
  parser.add_argument("trainCorpus", type=str,
    help="Name of the CoNLL-U training file.")
  parser.add_argument("--debug", "-d", default=False, action="store_true",
    help="Print debug infos on stderr.")
  args = parser.parse_args()

  transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
  # Word-index movement applied after each transition (see applyTransition).
  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}

  # Use the parsed positional argument, not sys.argv[1], which is "-d" when
  # the flag is given before the corpus path.
  sentences = Config.readConllu(args.trainCorpus)

  # Decode every sentence with the oracle and print it to stdout. Only the
  # first printed sentence gets the header.
  first = True
  for config in sentences :
    oracleDecode(transitionSet, strategy, config)
    config.print(sys.stdout, header=first)
    first = False
################################################################################