Skip to content
Snippets Groups Projects
Commit b1c976ae authored by Franck Dary
Browse files

Added bootstrap mode for oracle training

parent e0926705
Branches
No related tags found
No related merge requests found
......@@ -15,14 +15,14 @@ import Config
from conll18_ud_eval import load_conllu, evaluate
################################################################################
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, bootstrapInterval, silent=False) :
transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
sentences = Config.readConllu(filename)
if type == "oracle" :
trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, silent)
return
if type == "rl":
......@@ -34,25 +34,36 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silen
################################################################################
################################################################################
def extractExamples(ts, strat, config, dicts, debug=False) :
def extractExamples(debug, ts, strat, config, dicts, network=None) :
examples = []
with torch.no_grad() :
EOS = Transition("EOS")
config.moveWordIndex(0)
moved = True
while moved :
missingLinks = getMissingLinks(config)
candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config)])
candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)])
if len(candidates) == 0 :
break
candidate = candidates[0][1]
candidateIndex = [trans.name for trans in ts].index(candidate)
best = min([cand[0] for cand in candidates])
candidateOracle = random.sample([cand for cand in candidates if cand[0] == best], 1)[0][1]
features = Features.extractFeatures(dicts, config)
example = torch.cat([torch.LongTensor([candidateIndex]), features])
examples.append(example)
candidate = candidateOracle.name
if debug :
config.printForDebug(sys.stderr)
print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
print(str([[c[0],c[1].name] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
if network is not None :
output = network(features.unsqueeze(0).to(getDevice()))
scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1]
if debug :
  # `candidate` already holds the chosen transition's name (a str), so print
  # it directly; accessing `.name` on it would raise AttributeError.
  print(candidate, file=sys.stderr)
goldIndex = [trans.name for trans in ts].index(candidateOracle.name)
candidateIndex = [trans.name for trans in ts].index(candidate)
example = torch.cat([torch.LongTensor([goldIndex]), features])
examples.append(example)
moved = applyTransition(ts, strat, config, candidate)
EOS.apply(config)
......@@ -82,14 +93,15 @@ def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss
################################################################################
################################################################################
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
examples = []
def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, silent=False) :
dicts = Dicts()
dicts.readConllu(filename, ["FORM", "UPOS"])
dicts.save(modelDir+"/dicts.json")
examples = []
sentences = copy.deepcopy(sentencesOriginal)
print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
for config in sentences :
examples += extractExamples(transitionSet, strategy, config, dicts, debug)
examples += extractExamples(debug, transitionSet, strategy, config, dicts)
print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
examples = torch.stack(examples)
......@@ -100,6 +112,15 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
bestLoss = None
bestScore = None
for epoch in range(1,nbEpochs+1) :
if bootstrapInterval is not None and epoch > 1 and (epoch-1) % bootstrapInterval == 0 :
examples = []
sentences = copy.deepcopy(sentencesOriginal)
print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr)
for config in sentences :
examples += extractExamples(debug, transitionSet, strategy, config, dicts, network)
print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
examples = torch.stack(examples)
network.train()
examples = examples.index_select(0, torch.randperm(examples.size(0)))
totalLoss = 0.0
......
......@@ -27,6 +27,8 @@ if __name__ == "__main__" :
help="Size of each batch.")
parser.add_argument("--seed", default=100,
help="Random seed.")
# Parse the interval as an int: argparse returns a str by default, and this
# value is passed unconverted to Train.trainMode, where it is used in integer
# arithmetic (epoch modulo interval). The default stays None (bootstrap off).
parser.add_argument("--bootstrap", default=None, type=int,
  help="If not none, extract examples in bootstrap mode (oracle train only).")
parser.add_argument("--dev", default=None,
help="Name of the CoNLL-U file of the dev corpus.")
parser.add_argument("--debug", "-d", default=False, action="store_true",
......@@ -43,7 +45,7 @@ if __name__ == "__main__" :
torch.manual_seed(args.seed)
if args.mode == "train" :
Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent)
Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.silent)
elif args.mode == "decode" :
Decode.decodeMode(args.debug, args.corpus, args.type, args.model)
else :
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment