import random
import sys
from Transition import Transition, getMissingLinks, applyTransition
from Features import extractFeatures
from Dicts import Dicts
import Config
import torch

################################################################################
def randomDecode(ts, strat, config, debug=False) :
    """Decode *config* in place by applying uniformly random appliable transitions.

    ts     : list of Transition objects to choose from.
    strat  : dict from transition name to the value handed to applyTransition
             (its exact meaning lives in Transition.applyTransition — not visible here).
    config : sentence configuration, mutated in place.
    debug  : when true, dump the configuration and chosen transition to stderr.

    Loops until no transition in *ts* is appliable, then applies EOS.
    """
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    while True :
        candidates = [trans for trans in ts if trans.appliable(config)]
        if not candidates :
            break
        # random.choice is uniform; the old randint(0, 100) % len(...) was modulo-biased.
        candidate = random.choice(candidates)
        if debug :
            config.printForDebug(sys.stderr)
            print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)
        applyTransition(ts, strat, config, candidate.name)
    EOS.apply(config)
################################################################################

################################################################################
def oracleDecode(ts, strat, config, debug=False) :
    """Decode *config* in place, following the oracle's lowest-scoring transition.

    At each step every appliable transition is scored with
    trans.getOracleScore(config, missingLinks). The candidate with the smallest
    score wins; ties are broken by transition name, because the sort key is
    [score, name]. The loop stops when nothing is appliable or when
    applyTransition returns a falsy value, then EOS is applied.
    """
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    moved = True
    while moved :
        missingLinks = getMissingLinks(config)
        candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config)])
        if not candidates :
            break
        candidate = candidates[0][1]
        if debug :
            config.printForDebug(sys.stderr)
            print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
        moved = applyTransition(ts, strat, config, candidate)
    EOS.apply(config)
################################################################################

################################################################################
def decodeModel(ts, strat, config, network, dicts, debug) :
    """Decode *config* in place, greedily applying the network's most probable appliable transition.

    Features are built with extractFeatures(dicts, config) and given a leading
    batch dimension of 1. The network output is softmaxed over dim=1, and the
    probability at column i is taken to belong to ts[i]. The loop stops when
    nothing is appliable or when applyTransition returns a falsy value, then EOS
    is applied.
    """
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    moved = True
    network.eval()
    with torch.no_grad() :
        while moved :
            features = extractFeatures(dicts, config).unsqueeze(0)
            output = torch.nn.functional.softmax(network(features), dim=1)
            # Rank on the exact probability. Sorting on "%.2f" strings made close
            # scores tie and then picked by name instead of by probability.
            candidates = sorted(
                ((float(output[0][index]), trans.name)
                 for index, trans in enumerate(ts) if trans.appliable(config)),
                reverse=True)
            if not candidates :
                break
            candidate = candidates[0][1]
            if debug :
                config.printForDebug(sys.stderr)
                print(str([name for _, name in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
            moved = applyTransition(ts, strat, config, candidate)
    EOS.apply(config)
################################################################################

################################################################################
def _printSentences(sentences, output) :
    """Print every sentence to *output*; only the first carries the header (no-op when empty)."""
    for index, config in enumerate(sentences) :
        config.print(output, header=(index == 0))
################################################################################

################################################################################
def decodeMode(debug, filename, type, modelDir = None, network=None, dicts=None, output=None) :
    """Read a CoNLL-U file, decode every sentence and print the result.

    debug    : forwarded to the per-sentence decoder (stderr tracing).
    filename : CoNLL-U file parsed with Config.readConllu.
    type     : "random", "oracle" or "model"; anything else prints an error and exits(1).
    modelDir : directory holding dicts.json and network.pt. Required for "model"
               when *dicts* is not supplied.
    network, dicts : a preloaded model. When *dicts* is None, both dicts and the
               network are loaded from *modelDir*.
    output   : text stream for the decoded sentences. Defaults to the current
               sys.stdout, resolved at call time.
    """
    if output is None :
        output = sys.stdout
    transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
    strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
    sentences = Config.readConllu(filename)

    if type in ["random", "oracle"] :
        decodeFunc = oracleDecode if type == "oracle" else randomDecode
        for config in sentences :
            decodeFunc(transitionSet, strategy, config, debug)
        _printSentences(sentences, output)
    elif type == "model" :
        if dicts is None :
            if modelDir is None :
                print("ERROR : modelDir is required when dicts is not provided", file=sys.stderr)
                sys.exit(1)
            dicts = Dicts()
            dicts.load(modelDir+"/dicts.json")
            # NOTE(review): torch.load unpickles the file, so only load models from trusted directories.
            network = torch.load(modelDir+"/network.pt")
        for config in sentences :
            decodeModel(transitionSet, strategy, config, network, dicts, debug)
        _printSentences(sentences, output)
    else :
        print("ERROR : unknown type '%s'"%type, file=sys.stderr)
        sys.exit(1)
################################################################################