diff --git a/Rl.py b/Rl.py
index 6f30f407e39f8746479503fb98b4aa13dd8647a4..432a1a7027361363361a4db798fef9072fa9321e 100644
--- a/Rl.py
+++ b/Rl.py
@@ -51,8 +51,7 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
 ################################################################################
 
 ################################################################################
-def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
-  gamma = 0.8
+def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) :
   if len(memory) < batchSize :
     return 0.0
 
diff --git a/Train.py b/Train.py
index 9a5709163c07b487974a4e929f92cf20f8ff831d..5f0ed38496671e3c524cb64293fe8abef04257de 100644
--- a/Train.py
+++ b/Train.py
@@ -16,15 +16,15 @@ import Config
 from conll18_ud_eval import load_conllu, evaluate
 
 ################################################################################
-def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, silent=False) :
+def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, silent=False) :
   sentences = Config.readConllu(filename)
 
   if type == "oracle" :
-    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, silent)
+    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, silent)
     return
 
   if type == "rl":
-    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, silent)
+    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, silent)
     return
 
   print("ERROR : unknown type '%s'"%type, file=sys.stderr)
@@ -92,7 +92,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
 ################################################################################
 
 ################################################################################
-def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, silent=False) :
+def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, silent=False) :
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM","UPOS"], 2)
   dicts.addDict("HISTORY", {**{t.name : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
@@ -107,7 +107,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
   examples = torch.stack(examples)
 
   print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
-  optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
+  optimizer = torch.optim.Adam(network.parameters(), lr=lr)
   lossFct = torch.nn.CrossEntropyLoss()
   bestLoss = None
   bestScore = None
@@ -147,7 +147,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
 ################################################################################
 
 ################################################################################
-def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, silent=False) :
+def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, silent=False) :
 
   memory = None
   dicts = Dicts()
@@ -160,7 +160,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
   target_net.load_state_dict(policy_net.state_dict())
   target_net.eval()
   policy_net.train()
-  optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
+  optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)
   print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
 
   bestLoss = None
@@ -171,8 +171,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
   sentIndex = 0
 
   for epoch in range(1,nbIter+1) :
-    probaRandom = round(0.5*math.exp((-epoch+1)/4)+0.1, 2)
-    probaOracle = round(0.3*math.exp((-epoch+1)/2), 2)
+    probaRandom = round((probas[0][0]-probas[0][2])*math.exp((-epoch+1)/probas[0][1])+probas[0][2], 2) # decays from probas[0][0] to probas[0][2] at speed probas[0][1]
+    probaOracle = round((probas[1][0]-probas[1][2])*math.exp((-epoch+1)/probas[1][1])+probas[1][2], 2) # decays from probas[1][0] to probas[1][2] at speed probas[1][1]
     i = 0
     totalLoss = 0.0
     while True :
@@ -214,7 +214,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
         memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
         state = newState
         if i % batchSize == 0 :
-          totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
+          totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)
           if i % (1*batchSize) == 0 :
             target_net.load_state_dict(policy_net.state_dict())
             target_net.eval()
diff --git a/main.py b/main.py
index 81d311882083c6885732aaa3602085127ccbb666..d8fb72a6fa9e59a00eb21afe68b9c0675c56dab3 100755
--- a/main.py
+++ b/main.py
@@ -29,6 +29,10 @@ if __name__ == "__main__" :
     help="Size of each batch.")
   parser.add_argument("--seed", default=100,
     help="Random seed.")
+  parser.add_argument("--lr", default=0.0001,
+    help="Learning rate.")
+  parser.add_argument("--gamma", default=0.99,
+    help="Importance given to future rewards.")
   parser.add_argument("--bootstrap", default=None,
     help="If not none, extract examples in bootstrap mode (oracle train only).")
   parser.add_argument("--dev", default=None,
@@ -43,6 +47,10 @@ if __name__ == "__main__" :
     help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
   parser.add_argument("--reward", default="A",
     help="Reward function to use (A,B,C,D,E)")
+  parser.add_argument("--probaRandom", default="0.6,4,0.1",
+    help="Evolution of probability to chose action at random : (start value, decay speed, end value)")
+  parser.add_argument("--probaOracle", default="0.3,2,0.0",
+    help="Evolution of probability to chose action from oracle : (start value, decay speed, end value)")
   args = parser.parse_args()
 
   if args.debug :
@@ -65,7 +73,8 @@ if __name__ == "__main__" :
     json.dump([t.name for t in transitionSet], open(args.model+"/transitions.json", "w"))
     json.dump(strategy, open(args.model+"/strategy.json", "w"))
     print("Transition Set :", [trans.name for trans in transitionSet], file=sys.stderr)
-    Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, args.silent)
+    probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
+    Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, args.silent)
   elif args.mode == "decode" :
     transNames = json.load(open(args.model+"/transitions.json", "r"))
     transitionSet = [Transition(elem) for elem in transNames]
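
Note (not part of the patch): the triples passed to --probaRandom / --probaOracle parameterise an exponential decay from a start value towards an end value, evaluated once per epoch in trainModelRl. A minimal sketch of that schedule with the new defaults, assuming a hypothetical helper named decay:

# Illustration only, not code from this patch.
import math

def decay(start, speed, end, epoch):
  # Exponential decay from `start` (at epoch 1) towards `end`, as in Train.py.
  return round((start - end) * math.exp((-epoch + 1) / speed) + end, 2)

if __name__ == "__main__":
  probaRandom = [0.6, 4, 0.1]   # default of --probaRandom
  probaOracle = [0.3, 2, 0.0]   # default of --probaOracle
  for epoch in range(1, 11):
    print(epoch, decay(*probaRandom, epoch), decay(*probaOracle, epoch))

With these defaults the schedule reproduces the previously hardcoded values (0.5*exp((-epoch+1)/4)+0.1 and 0.3*exp((-epoch+1)/2)).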