From ead830ccfd9d8d7bfccd457af0d1961bbd11a3e6 Mon Sep 17 00:00:00 2001 From: Maxime Petit <maxime.petit.3@etu.univ-amu.fr> Date: Thu, 17 Jun 2021 14:33:53 +0200 Subject: [PATCH] New parameter: number of inapplicable actions picked before breaking the analysis --- Train.py | 21 +++++++++++++-------- main.py | 8 +++++--- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/Train.py b/Train.py index 01c5aaf..8b948d2 100644 --- a/Train.py +++ b/Train.py @@ -16,7 +16,7 @@ import Config from conll18_ud_eval import load_conllu, evaluate ################################################################################ -def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, predicted, silent=False) : +def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) : sentences = Config.readConllu(filename, predicted) if type == "oracle" : return if type == "rl": - trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, predicted, silent) + trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent) return print("ERROR : unknown type '%s'"%type, file=sys.stderr) @@ -56,15 +56,15 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) : candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1] if debug : print(str(candidate), file=sys.stderr) - + goldIndex = [str(trans) for trans in ts].index(str(candidateOracle)) example = torch.cat([torch.LongTensor([goldIndex]), features]) examples.append(example) 
moved = applyTransition(strat, config, candidate, None) - + EOS.apply(config, strat) - + return examples ################################################################################ @@ -149,7 +149,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr ################################################################################ ################################################################################ -def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, predicted, silent=False) : +def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) : memory = None dicts = Dicts() @@ -189,6 +189,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti sentence.moveWordIndex(0) state = policy_net.extractFeatures(dicts, sentence).to(getDevice()) + count = 0 + while True : missingLinks = getMissingLinks(sentence) if debug : @@ -206,11 +208,14 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc) reward = torch.FloatTensor([reward_]).to(getDevice()) - newState = None + #newState = None if appliable : applyTransition(strategy, sentence, action, reward_) newState = policy_net.extractFeatures(dicts, sentence).to(getDevice()) + else: + count+=1 + if memory is None : memory = ReplayMemory(5000, state.numel()) memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward) @@ -223,7 +228,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti policy_net.train() i += 1 - if state is None : + if state is None or count == countBreak: break if i >= nbExByEpoch : break diff --git a/main.py b/main.py index 
874ce0f..c82affe 100755 --- a/main.py +++ b/main.py @@ -41,7 +41,7 @@ if __name__ == "__main__" : parser.add_argument("--gamma", default=0.99, help="Importance given to future rewards.") parser.add_argument("--bootstrap", default=None, - help="If not none, extract examples in bootstrap mode (oracle train only).") + help="If not none, extract examples in bootstrap mode every n epochs (oracle train only).") parser.add_argument("--dev", default=None, help="Name of the CoNLL-U file of the dev corpus.") parser.add_argument("--incr", "-i", default=False, action="store_true", @@ -51,7 +51,7 @@ if __name__ == "__main__" : parser.add_argument("--silent", "-s", default=False, action="store_true", help="Don't print advancement infos.") parser.add_argument("--transitions", default="eager", - help="Transition set to use (eager | swift).") + help="Transition set to use (eager | swift | tagparser).") parser.add_argument("--ts", default="", help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"") parser.add_argument("--reward", default="A", @@ -60,6 +60,8 @@ if __name__ == "__main__" : help="Evolution of probability to chose action at random : (start value, decay speed, end value)") parser.add_argument("--probaOracle", default="0.3,2,0.0", help="Evolution of probability to chose action from oracle : (start value, decay speed, end value)") + parser.add_argument("--countBreak", default=1, + help="Number of unaplayable transition picked before breaking the analysis.") args = parser.parse_args() if args.debug : @@ -99,7 +101,7 @@ if __name__ == "__main__" : json.dump(strategy, open(args.model+"/strategy.json", "w")) printTS(transitionSet, sys.stderr) probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))] - Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), 
float(args.gamma), probas, args.predicted, args.silent) + Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent) elif args.mode == "decode" : transNames = json.load(open(args.model+"/transitions.json", "r")) transitionSet = [Transition(elem) for elem in transNames] -- GitLab