diff --git a/Train.py b/Train.py index 11215ed8665405de21bab5ff4314697a59780baf..3309201bab0f9095d1b8d8de1373aa6064b0318e 100644 --- a/Train.py +++ b/Train.py @@ -15,14 +15,14 @@ import Config from conll18_ud_eval import load_conllu, evaluate ################################################################################ -def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silent=False) : +def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, bootstrapInterval, silent=False) : transitionSet = [Transition(elem) for elem in ["RIGHT", "LEFT", "SHIFT", "REDUCE"]] strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0} sentences = Config.readConllu(filename) if type == "oracle" : - trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent) + trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, silent) return if type == "rl": @@ -34,29 +34,40 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silen ################################################################################ ################################################################################ -def extractExamples(ts, strat, config, dicts, debug=False) : +def extractExamples(debug, ts, strat, config, dicts, network=None) : examples = [] - - EOS = Transition("EOS") - config.moveWordIndex(0) - moved = True - while moved : - missingLinks = getMissingLinks(config) - candidates = sorted([[trans.getOracleScore(config, missingLinks), trans.name] for trans in ts if trans.appliable(config)]) - if len(candidates) == 0 : - break - candidate = candidates[0][1] - candidateIndex = [trans.name for trans in ts].index(candidate) - features = Features.extractFeatures(dicts, config) - example = torch.cat([torch.LongTensor([candidateIndex]), features]) - examples.append(example) - if debug : - config.printForDebug(sys.stderr) - print(str(candidates)+"\n"+("-"*80)+"\n", file=sys.stderr) - moved = applyTransition(ts, strat, config, candidate) - - EOS.apply(config) - + with torch.no_grad() : + EOS = Transition("EOS") + config.moveWordIndex(0) + moved = True + while moved : + missingLinks = getMissingLinks(config) + candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)]) + if len(candidates) == 0 : + break + best = min([cand[0] for cand in candidates]) + candidateOracle = random.sample([cand for cand in candidates if cand[0] == best], 1)[0][1] + features = Features.extractFeatures(dicts, config) + candidate = candidateOracle.name + if debug : + config.printForDebug(sys.stderr) + print(str([[c[0],c[1].name] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr) + if network is not None : + output = network(features.unsqueeze(0).to(getDevice())) + scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1] + candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1] + if debug : + print(candidate.name, file=sys.stderr) + + goldIndex = [trans.name for trans in ts].index(candidateOracle.name) + candidateIndex = [trans.name for trans in ts].index(candidate) + example = torch.cat([torch.LongTensor([goldIndex]), features]) + examples.append(example) + + moved = applyTransition(ts, strat, config, candidate) + + EOS.apply(config) + return examples ################################################################################ @@ -82,14 +93,15 @@ def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss ################################################################################ ################################################################################ -def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentences, silent=False) : - examples = [] +def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, silent=False) : dicts = Dicts() dicts.readConllu(filename, ["FORM", "UPOS"]) dicts.save(modelDir+"/dicts.json") + examples = [] + sentences = copy.deepcopy(sentencesOriginal) print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr) for config in sentences : - examples += extractExamples(transitionSet, strategy, config, dicts, debug) + examples += extractExamples(debug, transitionSet, strategy, config, dicts) print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr) examples = torch.stack(examples) @@ -100,6 +112,15 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr bestLoss = None bestScore = None for epoch in range(1,nbEpochs+1) : + if bootstrapInterval is not None and epoch > 1 and (epoch-1) % bootstrapInterval == 0 : + examples = [] + sentences = copy.deepcopy(sentencesOriginal) + print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr) + for config in sentences : + examples += extractExamples(debug, transitionSet, strategy, config, dicts, network) + print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr) + examples = torch.stack(examples) + network.train() examples = examples.index_select(0, torch.randperm(examples.size(0))) totalLoss = 0.0 diff --git a/main.py b/main.py index 43758e05b0e320c5bd9ee125e4489bbbe35bb71a..ef8a1ea5b732954056dd9df9acdd309cd3d52d95 100755 --- a/main.py +++ b/main.py @@ -27,6 +27,8 @@ if __name__ == "__main__" : help="Size of each batch.") parser.add_argument("--seed", default=100, help="Random seed.") + parser.add_argument("--bootstrap", default=None, + help="If not none, extract examples in bootstrap mode (oracle train only).") parser.add_argument("--dev", default=None, help="Name of the CoNLL-U file of the dev corpus.") parser.add_argument("--debug", "-d", default=False, action="store_true", @@ -43,7 +45,7 @@ if __name__ == "__main__" : torch.manual_seed(args.seed) if args.mode == "train" : - Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.silent) + Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.silent) elif args.mode == "decode" : Decode.decodeMode(args.debug, args.corpus, args.type, args.model) else :