diff --git a/main.py b/main.py
index c82affe6e8b6ab97d6643b85c4a44d2fc7e9e086..705d375f72e61adaae902cb388905891e28a74ce 100755
--- a/main.py
+++ b/main.py
@@ -85,25 +85,27 @@ if __name__ == "__main__" :
     tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
     tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
     transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+tagActions+args.ts.split(',')) if len(elem) > 0]
-    args.predicted = "HEAD,UPOS"
+    args.predictedStr = "HEAD,UPOS"
   elif args.transitions == "swift" :
     transitionSet = [Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]
-    args.predicted = "HEAD"
+    args.predictedStr = "HEAD"
   else :
     raise Exception("Unknown transition set '%s'"%args.transitions)
 
   strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0, "TAG" : 0}
 
-  args.predicted = set({colName for colName in args.predicted.split(',')})
-
   if args.mode == "train" :
-    json.dump([str(t) for t in transitionSet], open(args.model+"/transitions.json", "w"))
+    args.predicted = set({colName for colName in args.predictedStr.split(',')})
+    json.dump([args.predictedStr, [str(t) for t in transitionSet]], open(args.model+"/transitions.json", "w"))
     json.dump(strategy, open(args.model+"/strategy.json", "w"))
     printTS(transitionSet, sys.stderr)
     probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
     Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
   elif args.mode == "decode" :
-    transNames = json.load(open(args.model+"/transitions.json", "r"))
+    transInfos = json.load(open(args.model+"/transitions.json", "r"))
+    transNames = json.load(open(args.model+"/transitions.json", "r"))[1]
+    args.predictedStr = transInfos[0]
+    args.predicted = set({colName for colName in args.predictedStr.split(',')})
     transitionSet = [Transition(elem) for elem in transNames]
     strategy = json.load(open(args.model+"/strategy.json", "r"))
     printTS(transitionSet, sys.stderr)