Newer
Older
import sys
import random
from Transition import Transition, getMissingLinks, applyTransition
import Features
import torch
################################################################################
def extractExamples(ts, strat, config, dicts, debug=False) :
  """Walk a configuration with the oracle and record one training example per step.

  Starting from word index 0, each step does the following:
    * scores every transition of ``ts`` that is appliable to ``config`` with
      ``getOracleScore`` (given the current ``getMissingLinks(config)``);
    * picks the smallest ``[score, name]`` pair, i.e. the lowest score with
      ties broken by transition name (presumably a lower score means a less
      costly transition -- confirm against ``getOracleScore``);
    * records a tensor made of the chosen label followed by the features;
    * applies the chosen transition through ``applyTransition``.
  The walk stops when no transition is appliable, or when ``applyTransition``
  returns a falsy value. An ``EOS`` transition is then applied to ``config``.

  Args:
    ts: sequence of transitions. The label of an example is the position in
      ``ts`` of the first transition whose name equals the chosen name.
    strat: forwarded unchanged to ``applyTransition``.
    config: configuration being processed; it is mutated in place.
    dicts: forwarded unchanged to ``Features.extractFeatures``.
    debug: when true, the configuration and the sorted candidate list are
      written to stderr at every step.

  Returns:
    A list with one tensor per step. Each tensor is ``torch.cat`` of a
    one-element LongTensor holding the label and the extracted features.
    The features are presumably a 1-D LongTensor; confirm against
    ``Features.extractFeatures``.
  """
  eosTransition = Transition("EOS")
  config.moveWordIndex(0)
  collected = []
  while True :
    links = getMissingLinks(config)
    # [score, name] pairs: sorting them puts the lowest oracle score first and
    # breaks score ties by transition name.
    scored = [[t.getOracleScore(config, links), t.name] for t in ts if t.appliable(config)]
    scored.sort()
    if not scored :
      break
    bestName = scored[0][1]
    allNames = [t.name for t in ts]
    label = allNames.index(bestName)
    featureTensor = Features.extractFeatures(dicts, config)
    labelTensor = torch.LongTensor([label])
    collected.append(torch.cat([labelTensor, featureTensor]))
    if debug :
      config.printForDebug(sys.stderr)
      print(str(scored) + "\n" + "-" * 80 + "\n", file=sys.stderr)
    # Stop as soon as the transition could not be applied (falsy result).
    if not applyTransition(ts, strat, config, bestName) :
      break
  eosTransition.apply(config)
  return collected
################################################################################