Commit b1c976ae authored by Franck Dary

Added bootstrap mode for oracle training

parent e0926705
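This commit adds an optional bootstrap mode to oracle training: when a bootstrap interval is given, the training examples are re-extracted every bootstrapInterval epochs (after the first one), with the partially trained network choosing the transition that advances each configuration while the static oracle still supplies the gold label, as the diff below shows. A minimal, self-contained sketch of that idea follows; the helper names (oracle_best, model_best, apply_transition, extract_features) are hypothetical and not part of this repository:

# Sketch of bootstrap example extraction, not this repository's code:
# the static oracle provides the label to learn, but the partially trained
# model picks the transition that is actually applied to move the parse forward.
def extract_bootstrap_examples(configs, oracle_best, model_best, apply_transition, extract_features):
  examples = []
  for config in configs:
    while not config.is_final():                # assumed terminal-state test
      gold = oracle_best(config)                # gold transition from the static oracle
      predicted = model_best(config)            # transition chosen by the current model
      examples.append((extract_features(config), gold))
      apply_transition(config, predicted)
  return examples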
@@ -15,14 +15,14 @@ import Config
 from conll18_ud_eval import load_conllu, evaluate
 ################################################################################
-def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) :
+def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, bootstrapInterval, silent=False) :
   transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]]
   strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
   sentences = Config.readConllu(filename)
   if type == "oracle" :
-    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
+    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, silent)
     return
   if type == "rl":
@@ -34,25 +34,36 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silen
 ################################################################################
 ################################################################################
-def extractExamples(ts, strat, config, dicts, debug=False) :
+def extractExamples(debug, ts, strat, config, dicts, network=None) :
   examples = []
+  with torch.no_grad() :
     EOS = Transition("EOS")
     config.moveWordIndex(0)
     moved = True
     while moved :
      missingLinks = getMissingLinks(config)
-      candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config)])
+      candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)])
       if len(candidates) == 0 :
         break
-      candidate = candidates[0][1]
-      candidateIndex = [trans.name for trans in ts].index(candidate)
+      best = min([cand[0] for cand in candidates])
+      candidateOracle = random.sample([cand for cand in candidates if cand[0] == best], 1)[0][1]
       features = Features.extractFeatures(dicts, config)
-      example = torch.cat([torch.LongTensor([candidateIndex]), features])
-      examples.append(example)
+      candidate = candidateOracle.name
       if debug :
         config.printForDebug(sys.stderr)
-        print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr)
+        print(str([[c[0],c[1].name] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
+      if network is not None :
+        output = network(features.unsqueeze(0).to(getDevice()))
+        scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
+        candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1]
+        if debug :
+          print(candidate.name, file=sys.stderr)
+      goldIndex = [trans.name for trans in ts].index(candidateOracle.name)
+      candidateIndex = [trans.name for trans in ts].index(candidate)
+      example = torch.cat([torch.LongTensor([goldIndex]), features])
+      examples.append(example)
       moved = applyTransition(ts, strat, config, candidate)
     EOS.apply(config)
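In both the static and bootstrap settings, each extracted example is a single LongTensor whose first entry is the index of the gold (oracle) transition in the transition set and whose remaining entries are the feature vector; only the transition applied to the configuration may come from the network. A tiny illustration with made-up feature values:

import torch

ts_names = ["RIGHT", "LEFT", "SHIFT", "REDUCE"]              # order defines the indices
features = torch.LongTensor([4, 17, 2])                      # made-up feature ids
goldIndex = ts_names.index("SHIFT")                          # 2
example = torch.cat([torch.LongTensor([goldIndex]), features])
print(example)                                               # tensor([ 2,  4, 17,  2])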
@@ -82,14 +93,15 @@ def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss
 ################################################################################
 ################################################################################
-def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
-  examples = []
+def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, silent=False) :
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM", "UPOS"])
   dicts.save(modelDir+"/dicts.json")
+  examples = []
+  sentences = copy.deepcopy(sentencesOriginal)
   print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
   for config in sentences :
-    examples += extractExamples(transitionSet, strategy, config, dicts, debug)
+    examples += extractExamples(debug, transitionSet, strategy, config, dicts)
   print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
   examples = torch.stack(examples)
@@ -100,6 +112,15 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
   bestLoss = None
   bestScore = None
   for epoch in range(1,nbEpochs+1) :
+    if bootstrapInterval is not None and epoch > 1 and (epoch-1) % bootstrapInterval == 0 :
+      examples = []
+      sentences = copy.deepcopy(sentencesOriginal)
+      print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr)
+      for config in sentences :
+        examples += extractExamples(debug, transitionSet, strategy, config, dicts, network)
+      print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
+      examples = torch.stack(examples)
     network.train()
     examples = examples.index_select(0, torch.randperm(examples.size(0)))
     totalLoss = 0.0
...
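The re-extraction guard (epoch > 1 and (epoch-1) % bootstrapInterval == 0) fires for the first time at epoch bootstrapInterval + 1 and then every bootstrapInterval epochs thereafter; for example, with bootstrapInterval = 3 over 10 epochs it triggers at epochs 4, 7 and 10:

bootstrapInterval = 3
trigger_epochs = [e for e in range(1, 11) if e > 1 and (e - 1) % bootstrapInterval == 0]
print(trigger_epochs)   # [4, 7, 10]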
@@ -27,6 +27,8 @@ if __name__ == "__main__" :
     help="Size of each batch.")
   parser.add_argument("--seed", default=100,
     help="Random seed.")
+  parser.add_argument("--bootstrap", default=None,
+    help="If not none, extract examples in bootstrap mode (oracle train only).")
   parser.add_argument("--dev", default=None,
     help="Name of the CoNLL-U file of the dev corpus.")
   parser.add_argument("--debug", "-d", default=False, action="store_true",
@@ -43,7 +45,7 @@ if __name__ == "__main__" :
   torch.manual_seed(args.seed)
   if args.mode == "train" :
-    Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent)
+    Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.silent)
   elif args.mode == "decode" :
     Decode.decodeMode(args.debug, args.corpus, args.type, args.model)
   else :
...
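Note that --bootstrap is declared without a type, so argparse hands the value to Train.trainMode (and from there to trainModelOracle) as a string, or as None when the option is omitted; a standalone check of that behaviour:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--bootstrap", default=None,
  help="If not none, extract examples in bootstrap mode (oracle train only).")
print(parser.parse_args([]).bootstrap)                           # None (bootstrap mode off)
print(repr(parser.parse_args(["--bootstrap", "10"]).bootstrap))  # '10' (a string, not an int)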