"""Decoders that repeatedly apply transitions from a transition set to a configuration."""

import random
import sys

import torch

from Transition import Transition, getMissingLinks, applyTransition
from Features import extractFeatures

################################################################################
def randomDecode(ts, strat, config, debug=False):
    """Apply uniformly chosen appliable transitions until none is appliable.

    Args:
        ts: iterable of transitions; each must provide ``appliable(config)``
            and ``name``.
        strat: forwarded unchanged to ``applyTransition``.
        config: configuration decoded in place; its word index is reset to 0
            before decoding starts.
        debug: if true, dump the configuration and the chosen transition name
            to stderr at every step.

    After the loop ends, the "EOS" transition is applied to ``config``.
    """
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    while True:
        candidates = [trans for trans in ts if trans.appliable(config)]
        if not candidates:
            break
        # random.choice is uniform; randint(0, 100) % len(candidates) was biased
        # whenever len(candidates) does not divide 101.
        candidate = random.choice(candidates)
        if debug:
            config.printForDebug(sys.stderr)
            print(candidate.name + "\n" + ("-" * 80) + "\n", file=sys.stderr)
        # NOTE(review): unlike the other decoders, the return value of
        # applyTransition is ignored here; the loop only ends when no
        # transition is appliable.
        applyTransition(ts, strat, config, candidate.name)
    EOS.apply(config)
################################################################################

################################################################################
def oracleDecode(ts, strat, config, debug=False):
    """Decode ``config`` by following the oracle.

    At each step every appliable transition is scored with
    ``trans.getOracleScore(config, getMissingLinks(config))``, and the one with
    the lowest score is applied (ties broken by ascending name). Presumably a
    lower score means a lower oracle cost; confirm against ``getOracleScore``.

    Args:
        ts: iterable of transitions providing ``appliable``,
            ``getOracleScore`` and ``name``.
        strat: forwarded unchanged to ``applyTransition``.
        config: configuration decoded in place; its word index is reset to 0
            first.
        debug: if true, dump the configuration and the scored candidate list
            to stderr at every step.

    Decoding stops when no transition is appliable or ``applyTransition``
    returns a falsy value; the "EOS" transition is then applied.
    """
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    moved = True
    while moved:
        missingLinks = getMissingLinks(config)
        candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name]
                             for trans in ts if trans.appliable(config)])
        if not candidates:
            break
        candidate = candidates[0][1]
        if debug:
            config.printForDebug(sys.stderr)
            print(str(candidates) + "\n" + ("-" * 80) + "\n", file=sys.stderr)
        moved = applyTransition(ts, strat, config, candidate)
    EOS.apply(config)
################################################################################

################################################################################
def decodeModel(ts, strat, config, network, dicts, debug=False):
    """Greedily decode ``config`` with a neural network.

    At each step features are extracted with ``extractFeatures(dicts, config)``
    and given to ``network`` as a batch of one. A softmax is taken over
    dimension 1; entry ``i`` is read as the probability of ``ts[i]`` (assumes
    the network outputs a ``(1, len(ts))`` tensor — TODO confirm). The
    appliable transition with the highest probability is applied.

    Args:
        ts: indexable sequence of transitions providing ``appliable`` and
            ``name``; its order must match the network's output order.
        strat: forwarded unchanged to ``applyTransition``.
        config: configuration decoded in place; its word index is reset to 0
            first.
        network: callable torch module; it is switched to eval mode for the
            duration of decoding and its previous train/eval mode is restored
            afterwards.
        dicts: forwarded unchanged to ``extractFeatures``.
        debug: if true, dump the configuration and the ranked candidates (with
            probabilities) to stderr at every step.

    Decoding stops when no transition is appliable or ``applyTransition``
    returns a falsy value; the "EOS" transition is then applied.
    """
    EOS = Transition("EOS")
    config.moveWordIndex(0)
    moved = True
    wasTraining = network.training
    network.eval()
    try:
        with torch.no_grad():
            while moved:
                features = extractFeatures(dicts, config).unsqueeze(0)
                output = torch.nn.functional.softmax(network(features), dim=1)
                # Rank by full-precision probability: ranking by a "%.2f" string
                # tied distinct scores and let the name decide the winner.
                # Exact ties are still broken by descending name, as before.
                scored = sorted(
                    [(float(output[0][index]), ts[index].name)
                     for index in range(len(ts)) if ts[index].appliable(config)],
                    reverse=True)
                if not scored:
                    break
                candidate = scored[0][1]
                if debug:
                    config.printForDebug(sys.stderr)
                    ranked = [["%.2f" % score, name] for score, name in scored]
                    print(str(ranked) + "\n" + ("-" * 80) + "\n", file=sys.stderr)
                moved = applyTransition(ts, strat, config, candidate)
    finally:
        # Do not leave a training network stuck in eval mode.
        network.train(wasTraining)
    EOS.apply(config)
################################################################################