diff --git a/Rl.py b/Rl.py
index 6f30f407e39f8746479503fb98b4aa13dd8647a4..432a1a7027361363361a4db798fef9072fa9321e 100644
--- a/Rl.py
+++ b/Rl.py
@@ -51,8 +51,7 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
 ################################################################################
 
 ################################################################################
-def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
-  gamma = 0.8
+def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) :
   if len(memory) < batchSize :
     return 0.0
 
diff --git a/Train.py b/Train.py
index 9a5709163c07b487974a4e929f92cf20f8ff831d..5f0ed38496671e3c524cb64293fe8abef04257de 100644
--- a/Train.py
+++ b/Train.py
@@ -16,15 +16,15 @@ import Config
 from conll18_ud_eval import load_conllu, evaluate
 
 ################################################################################
-def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, silent=False) :
+def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, silent=False) :
   sentences = Config.readConllu(filename)
 
   if type == "oracle" :
-    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, silent)
+    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, silent)
     return
 
   if type == "rl":
-    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, silent)
+    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, silent)
     return
 
   print("ERROR : unknown type '%s'"%type, file=sys.stderr)
@@ -92,7 +92,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
 ################################################################################
 
 ################################################################################
-def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, silent=False) :
+def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, silent=False) :
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM","UPOS"], 2)
   dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
@@ -107,7 +107,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
   examples = torch.stack(examples)
 
   print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
-  optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
+  optimizer = torch.optim.Adam(network.parameters(), lr=lr)
   lossFct = torch.nn.CrossEntropyLoss()
   bestLoss = None
   bestScore = None
@@ -147,7 +147,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
 ################################################################################
 
 ################################################################################
-def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, silent=False) :
+def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, silent=False) :
 
   memory = None
   dicts = Dicts()
@@ -160,7 +160,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
   target_net.load_state_dict(policy_net.state_dict())
   target_net.eval()
   policy_net.train()
-  optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
+  optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)
   print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
 
   bestLoss = None
@@ -171,8 +171,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
   sentIndex = 0
 
   for epoch in range(1,nbIter+1) :
-    probaRandom = round(0.5*math.exp((-epoch+1)/4)+0.1, 2)
-    probaOracle = round(0.3*math.exp((-epoch+1)/2), 2)
+    probaRandom = round((probas[0][0]-probas[0][2])*math.exp((-epoch+1)/probas[0][1])+probas[0][2], 2)
+    probaOracle = round((probas[1][0]-probas[1][2])*math.exp((-epoch+1)/probas[1][1])+probas[1][2], 2)
     i = 0
     totalLoss = 0.0
     while True :
@@ -214,7 +214,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
       memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
       state = newState
       if i % batchSize == 0 :
-        totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
+        totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)
         if i % (1*batchSize) == 0 :
           target_net.load_state_dict(policy_net.state_dict())
           target_net.eval()
diff --git a/main.py b/main.py
index 81d311882083c6885732aaa3602085127ccbb666..d8fb72a6fa9e59a00eb21afe68b9c0675c56dab3 100755
--- a/main.py
+++ b/main.py
@@ -29,6 +29,10 @@
     help="Size of each batch.")
   parser.add_argument("--seed", default=100,
     help="Random seed.")
+  parser.add_argument("--lr", default=0.0001,
+    help="Learning rate.")
+  parser.add_argument("--gamma", default=0.99,
+    help="Importance given to future rewards.")
   parser.add_argument("--bootstrap", default=None,
     help="If not none, extract examples in bootstrap mode (oracle train only).")
   parser.add_argument("--dev", default=None,
@@ -43,6 +47,10 @@
     help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
   parser.add_argument("--reward", default="A",
     help="Reward function to use (A,B,C,D,E)")
+  parser.add_argument("--probaRandom", default="0.6,4,0.1",
+    help="Evolution of probability to choose action at random : (start value, decay speed, end value)")
+  parser.add_argument("--probaOracle", default="0.3,2,0.0",
+    help="Evolution of probability to choose action from oracle : (start value, decay speed, end value)")
   args = parser.parse_args()
 
   if args.debug :
@@ -65,7 +73,8 @@
     json.dump([t.name for t in transitionSet], open(args.model+"/transitions.json", "w"))
     json.dump(strategy, open(args.model+"/strategy.json", "w"))
     print("Transition Set :", [trans.name for trans in transitionSet], file=sys.stderr)
-    Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, args.silent)
+    probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
+    Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, args.silent)
   elif args.mode == "decode" :
     transNames = json.load(open(args.model+"/transitions.json", "r"))
     transitionSet = [Transition(elem) for elem in transNames]