import random
import sys
from Transition import Transition, getMissingLinks, applyTransition
from Dicts import Dicts
from Util import getDevice
import Config
import torch

################################################################################
def randomDecode(ts, strat, config, debug=False) :
    """Decode `config` by repeatedly applying a random appliable transition.

    Each step picks uniformly among the transitions of `ts` that are appliable
    to `config`. Decoding stops as soon as none is appliable, after which an
    "EOS" transition is applied to the configuration.

    Args:
        ts: sequence of Transition objects to choose from.
        strat: strategy forwarded to `applyTransition`.
        config: configuration decoded in place.
        debug: when True, dump the configuration and chosen transition to stderr.
    """
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    while True :
        candidates = [trans for trans in ts if trans.appliable(config)]
        if not candidates :
            break
        # random.choice is uniform over every candidate; a modulo of a bounded
        # randint would be biased and could never reach indices above 100.
        candidate = random.choice(candidates)
        if debug :
            config.printForDebug(sys.stderr)
            print(candidate.name+"\n"+("-"*80)+"\n", file=sys.stderr)
        applyTransition(ts, strat, config, candidate.name, 0.)

    # NOTE(review): decodeModel calls EOS.apply(config, strat) whereas this
    # passes only config -- confirm which Transition.apply signature is current.
    EOS.apply(config)
################################################################################

################################################################################
def oracleDecode(ts, strat, config, debug=False) :
    """Decode `config` by always applying the appliable transition with the lowest oracle score.

    Candidates are sorted by [oracle score, transition name], so ties are
    broken by name. Decoding stops when no transition is appliable or when
    `applyTransition` returns a falsy value (the configuration did not move).
    An "EOS" transition is then applied.

    Args:
        ts: sequence of Transition objects to choose from.
        strat: strategy forwarded to `applyTransition`.
        config: configuration decoded in place.
        debug: when True, dump the configuration and scored candidates to stderr.
    """
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    moved = True
    while moved :
        # Recomputed every step because each applied transition changes the config.
        missingLinks = getMissingLinks(config)
        candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config)])
        if not candidates :
            break
        candidate = candidates[0][1]
        if debug :
            config.printForDebug(sys.stderr)
            print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
        moved = applyTransition(ts, strat, config, candidate, 0.)

    # NOTE(review): decodeModel passes strat to EOS.apply; confirm signature.
    EOS.apply(config)
################################################################################

################################################################################
def decodeModel(ts, strat, config, network, dicts, debug) :
    """Greedily decode `config` with `network`.

    At each step the network scores every transition (output index i scores
    ts[i]), and the highest-scoring appliable transition is applied. Decoding
    stops when no transition is appliable or `applyTransition` returns a falsy
    value; "EOS" is then applied.

    The network is put in eval mode and moved to the device returned by
    `getDevice()` for decoding. It is moved back to its previous device
    afterwards, even if decoding raises.

    Args:
        ts: sequence of transitions, indexed like the network output.
        strat: strategy forwarded to `applyTransition` and `EOS.apply`.
        config: configuration decoded in place.
        network: model exposing `currentDevice()` and `extractFeatures(dicts, config)`.
        dicts: dictionaries used for feature extraction.
        debug: when True, dump every step's scores and chosen action to stderr.
    """
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    moved = True
    network.eval()

    currentDevice = network.currentDevice()
    decodeDevice = getDevice()
    network.to(decodeDevice)

    if debug :
        print("\n"+("-"*80)+"\n", file=sys.stderr)

    try :
        with torch.no_grad():
            while moved :
                # unsqueeze(0) adds a batch dimension of size 1; output[0] is
                # assumed to hold one score per transition in ts -- TODO confirm.
                features = network.extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
                output = network(features)
                # Rows are [score, appliable, name], best score first.
                scores = sorted([[float(output[0][index]), trans.appliable(config), trans.name] for index, trans in enumerate(ts)])[::-1]
                candidates = [[cand[0], cand[2]] for cand in scores if cand[1]]
                if not candidates :
                    break
                candidate = candidates[0][1]
                if debug :
                    config.printForDebug(sys.stderr)
                    print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+"Chosen action : %s"%candidate+"\n"+("-"*80)+"\n", file=sys.stderr)
                moved = applyTransition(ts, strat, config, candidate, 0.)

        EOS.apply(config, strat)
    finally :
        # Always give the network back on the device it came from.
        network.to(currentDevice)
################################################################################

################################################################################
def _printSentences(sentences, output) :
    """Print every configuration to `output`; only the first one gets the header."""
    for index, config in enumerate(sentences) :
        config.print(output, header=(index == 0))
################################################################################

################################################################################
def decodeMode(debug, filename, type, transitionSet, strategy, modelDir=None, network=None, dicts=None, output=sys.stdout) :
    """Read a CoNLL-U file, decode every sentence, and print the result.

    Args:
        debug: forwarded to the decoding function.
        filename: CoNLL-U file read with `Config.readConllu`.
        type: "random", "oracle" or "model"; selects the decoding function.
        transitionSet: transitions available to the decoder.
        strategy: strategy forwarded to the decoder.
        modelDir: directory holding "dicts.json" and "network.pt"; used only
            when `type` is "model" and `dicts` / `network` are not supplied.
        network: already-loaded network for "model" decoding (loaded from
            `modelDir` when None).
        dicts: already-loaded Dicts for "model" decoding (loaded from
            `modelDir` when None).
        output: stream the decoded sentences are printed to.

    Exits the process with status 1 when `type` is unknown.
    """
    sentences = Config.readConllu(filename)

    if type in ["random", "oracle"] :
        decodeFunc = oracleDecode if type == "oracle" else randomDecode
        for config in sentences :
            decodeFunc(transitionSet, strategy, config, debug)
        _printSentences(sentences, output)
    elif type == "model" :
        if dicts is None :
            dicts = Dicts()
            dicts.load(modelDir+"/dicts.json")
        if network is None :
            # SECURITY NOTE: torch.load unpickles arbitrary objects -- only load
            # trusted model directories. On torch >= 2.6 loading a whole pickled
            # Module may require weights_only=False.
            network = torch.load(modelDir+"/network.pt")
        for config in sentences :
            decodeModel(transitionSet, strategy, config, network, dicts, debug)
        _printSentences(sentences, output)
    else :
        print("ERROR : unknown type '%s'"%type, file=sys.stderr)
        sys.exit(1)
################################################################################