diff --git a/Train.py b/Train.py
index a5afe28f23e3843e5a3ef9baf050571b4d238306..140c8960fc33b1278b624fcf01163bea2ecdf5af 100644
--- a/Train.py
+++ b/Train.py
@@ -229,14 +229,18 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
       state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
 
       count = 0
+      list_probas = []
+      for pb in range(len(probas)):
+        list_probas.append([round((probas[pb][0][0]-probas[pb][0][2])*math.exp((-epoch+1)/probas[pb][0][1])+probas[pb][0][2], 2),
+                           round((probas[pb][1][0]-probas[pb][1][2])*math.exp((-epoch+1)/probas[pb][1][1])+probas[pb][1][2], 2)])
 
       while True :
         missingLinks = getMissingLinks(sentence)
         transitionSet = transitionSets[sentence.state]
         fromState = sentence.state
         toState = sentence.state
-        probaRandom = round((probas[fromState][0][0]-probas[fromState][0][2])*math.exp((-epoch+1)/probas[fromState][0][1])+probas[fromState][0][2], 2)
-        probaOracle = round((probas[fromState][1][0]-probas[fromState][1][2])*math.exp((-epoch+1)/probas[fromState][1][1])+probas[fromState][1][2], 2)
+        probaRandom = list_probas[fromState][0]
+        probaOracle = list_probas[fromState][1]
         
 
         if debug :
diff --git a/main.py b/main.py
index 56458896315091fe649200cb3ab6f10718251f22..594e0fd606ba3f20cd61db826c939b2d54b24f24 100755
--- a/main.py
+++ b/main.py
@@ -63,6 +63,8 @@ if __name__ == "__main__" :
     help="Evolution of probability to chose action at random : (start value, decay speed, end value)")
   parser.add_argument("--probaOracle", default="0.3,2,0.0",
     help="Evolution of probability to chose action from oracle : (start value, decay speed, end value)")
+  parser.add_argument("--probaStateBack", default="0.0,20,1.0-1.0,20,0.0",
+    help="Evolution of probability to chose action in state Back with random and oracle.")
   parser.add_argument("--countBreak", default=1,
     help="Number of unaplayable transition picked before breaking the analysis.")
   args = parser.parse_args()
@@ -89,7 +91,7 @@ if __name__ == "__main__" :
     args.states = ["tagger"]
     strategy = {"TAG" : (1,0)}
     args.network = "tagger"
-    args.probas = [[[0.6,4,0.1],[0.3,2,0.0]]]
+    probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   elif args.transitions == "taggerbt" :
     tmpDicts = Dicts()
     tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
@@ -99,13 +101,14 @@ if __name__ == "__main__" :
     args.states = ["tagger", "backer"]
     strategy = {"TAG" : (1,1), "NOBACK" : (0,0)}
     args.network = "tagger"
-    args.probas = [[[0.6,4,0.1],[0.3,2,0.0]],[[0.6,4,0.1],[0.3,2,0.0]]]
+    probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
+              [list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))]]
   elif args.transitions == "eager" :
     transitionSets = [[Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]]
     args.predictedStr = "HEAD"
     args.states = ["parser"]
     strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}
-    args.probas = [[[0.6,4,0.1],[0.3,2,0.0]]]
+    probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   elif args.transitions == "tagparser" :
     tmpDicts = Dicts()
     tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
@@ -114,7 +117,8 @@ if __name__ == "__main__" :
     args.predictedStr = "HEAD,UPOS"
     args.states = ["tagger", "parser"]
     strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1)}
-    args.probas = [[[0.6,4,0.1],[0.3,2,0.0]],[[0.6,4,0.1],[0.3,2,0.0]]]
+    probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
+              [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   elif args.transitions == "tagparserbt" :
     tmpDicts = Dicts()
     tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
@@ -123,12 +127,16 @@ if __name__ == "__main__" :
     args.predictedStr = "HEAD,UPOS"
     args.states = ["tagger", "parser", "backer"]
     strategy = {"RIGHT" : (1,2), "SHIFT" : (1,2), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1), "NOBACK" : (0,0)}
-    args.probas = [[[0.6,4,0.1],[0.3,2,0.0]],[[0.6,4,0.1],[0.3,2,0.0]],[[0.0,25,1.0],[1.0,25,0.0]]]
+    probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
+              [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))],
+              [list(map(float, args.probaStateBack.split('-')[0].split(','))), list(map(float, args.probaStateBack.split('-')[1].split(',')))]]
+
   elif args.transitions == "swift" :
     transitionSets = [[Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]]
     args.predictedStr = "HEAD"
     args.states = ["parser"]
     strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}
+    probas = [[list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]]
   else :
     raise Exception("Unknown transition set '%s'"%args.transitions)
 
@@ -137,8 +145,7 @@ if __name__ == "__main__" :
     json.dump([args.predictedStr, [[str(t) for t in transitionSet] for transitionSet in transitionSets]], open(args.model+"/transitions.json", "w"))
     json.dump(strategy, open(args.model+"/strategy.json", "w"))
     printTS(transitionSets, sys.stderr)
-    probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
-    Train.trainMode(args.debug, args.network, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), args.probas, int(args.countBreak), args.predicted, args.silent)
+    Train.trainMode(args.debug, args.network, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
   elif args.mode == "decode" :
     transInfos = json.load(open(args.model+"/transitions.json", "r"))
     transNames = json.load(open(args.model+"/transitions.json", "r"))[1]