diff --git a/Config.py b/Config.py
index 3862951c569f35bc789151869ba9eac2d2ac72de..eee07f201c84228600ddb484ef48ddf913d110db 100644
--- a/Config.py
+++ b/Config.py
@@ -12,6 +12,7 @@ class Config :
     self.predicted = predicted
     self.wordIndex = 0
     self.maxWordIndex = 0 #To keep a track of the max value, in case of backtrack
+    self.state = 0 #State of the analysis (e.g. 0=tagger, 1=parser)
     self.stack = []
     self.comments = []
     self.history = []
@@ -83,6 +84,7 @@ class Config :
     printedCols = ["ID","FORM","UPOS","HEAD","DEPREL"]
     left = 5
     right = 5
+    print("state :", self.state, file=output)
     print("stack :",[self.getAsFeature(ind, "ID") for ind in self.stack], file=output)
     print("history :",[str(trans) for trans in self.history[-10:]], file=output)
     print("historyPop :",[(str(c[0]),"dat:"+str(c[1]),"mvt:"+str(c[2]),"reward:"+str(c[3])) for c in self.historyPop[-10:]], file=output)
diff --git a/Decode.py b/Decode.py
index b0fcfb2378e000c464623386ea8ebdaef4fbd394..34d5ef45e3291f7a7e0936486e2809b705136617 100644
--- a/Decode.py
+++ b/Decode.py
@@ -11,6 +11,7 @@ import torch
 def randomDecode(ts, strat, config, debug=False) :
   EOS = Transition("EOS")
   config.moveWordIndex(0)
+  config.state = 0
   while True :
     candidates = [trans for trans in ts if trans.appliable(config)]
     if len(candidates) == 0 :
@@ -28,6 +29,7 @@ def randomDecode(ts, strat, config, debug=False) :
 def oracleDecode(ts, strat, config, debug=False) :
   EOS = Transition("EOS")
   config.moveWordIndex(0)
+  config.state = 0
   moved = True
   while moved :
     missingLinks = getMissingLinks(config)
@@ -44,9 +46,10 @@ def oracleDecode(ts, strat, config, debug=False) :
 ################################################################################
 
 ################################################################################
-def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) :
+def decodeModel(transitionSets, strat, config, network, dicts, debug, rewardFunc) :
   EOS = Transition("EOS")
   config.moveWordIndex(0)
+  config.state = 0
   moved = True
   network.eval()
 
@@ -59,6 +62,8 @@ def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) :
 
   with torch.no_grad():
     while moved :
+      ts = transitionSets[config.state]
+      network.setState(config.state)
       features = network.extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
       output = network(features)
       scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1]
diff --git a/Networks.py b/Networks.py
index 8dfd4b6edb7e9422f06a87bb336e54fc2c61522f..e894867c443497722a90e125d8989d00cb33170c 100644
--- a/Networks.py
+++ b/Networks.py
@@ -5,11 +5,12 @@ import Features
 
 ################################################################################
 class BaseNet(nn.Module):
-  def __init__(self, dicts, outputSize, incremental) :
+  def __init__(self, dicts, outputSizes, incremental) :
     super().__init__()
     self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
 
     self.incremental = incremental
+    self.state = 0
     self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
     self.historyNb = 5
     self.suffixSize = 4
@@ -19,15 +20,19 @@ class BaseNet(nn.Module):
     self.embSize = 64
     self.nbTargets = len(self.featureFunction.split())
     self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
-    self.outputSize = outputSize
+    self.outputSizes = outputSizes
     for name in dicts.dicts :
       self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
     self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
-    self.fc2 = nn.Linear(1600, outputSize)
+    for i in range(len(outputSizes)) :
+      self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i]))
     self.dropout = nn.Dropout(0.3)
 
     self.apply(self.initWeights)
 
+  def setState(self, state) : #Choose the output layer ("output_"+str(state)) that forward() applies
+    self.state = state
+
   def forward(self, x) :
     embeddings = []
     for i in range(len(self.columns)) :
@@ -48,7 +53,7 @@ class BaseNet(nn.Module):
       curIndex = curIndex+self.suffixSize
     y = self.dropout(y)
     y = F.relu(self.dropout(self.fc1(y)))
-    y = self.fc2(y)
+    y = getattr(self, "output_"+str(self.state))(y)
     return y
 
   def currentDevice(self) :
@@ -70,11 +75,12 @@ class BaseNet(nn.Module):
 
 ################################################################################
 class LSTMNet(nn.Module):
-  def __init__(self, dicts, outputSize, incremental) :
+  def __init__(self, dicts, outputSizes, incremental) :
     super().__init__()
     self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
 
     self.incremental = incremental
+    self.state = 0
     self.featureFunctionLSTM = "b.-2 b.-1 b.0 b.1 b.2"
     self.featureFunction = "s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
     self.historyNb = 5
@@ -91,11 +97,15 @@ class LSTMNet(nn.Module):
     self.lstmFeat = nn.LSTM(len(self.columns)*self.embSize, len(self.columns)*int(self.embSize/2), 1, batch_first=True, bidirectional = True)
     self.lstmHist = nn.LSTM(self.embSize, int(self.embSize/2), 1, batch_first=True, bidirectional = True)
     self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
-    self.fc2 = nn.Linear(1600, outputSize)
+    for i in range(len(outputSizes)) :
+      self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i]))
     self.dropout = nn.Dropout(0.3)
 
     self.apply(self.initWeights)
 
+  def setState(self, state) : #Choose the output layer ("output_"+str(state)) that forward() applies
+    self.state = state
+
   def forward(self, x) :
     embeddings = []
     embeddingsLSTM = []
@@ -116,7 +126,7 @@ class LSTMNet(nn.Module):
       y = torch.cat([y, historyEmb],-1)
     y = self.dropout(y)
     y = F.relu(self.dropout(self.fc1(y)))
-    y = self.fc2(y)
+    y = getattr(self, "output_"+str(self.state))(y)
     return y
 
   def currentDevice(self) :
diff --git a/Rl.py b/Rl.py
index f870bcc29624ca364c1a338ac2dc78badd907162..4767c30b89865d2dcc87506daa169f70b6f54173 100644
--- a/Rl.py
+++ b/Rl.py
@@ -7,7 +7,7 @@ from Util import getDevice
 
 ################################################################################
 class ReplayMemory() :
-  def __init__(self, capacity, stateSize) :
+  def __init__(self, capacity, stateSize, nbStates) :
     self.capacity = capacity
     self.states = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
     self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
diff --git a/Train.py b/Train.py
index 8b948d22ab8ce610d3b0535e8f4b4e2ed666c7fe..7c009f6763e1df48ab4995ecd8956e7eab70cb47 100644
--- a/Train.py
+++ b/Train.py
@@ -32,13 +32,16 @@ def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter,
 ################################################################################
 
 ################################################################################
-def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
-  examples = []
+# Returns one list of examples per transition set : the list at index i holds the examples extracted while config.state == i
+def extractExamples(debug, transitionSets, strat, config, dicts, network, dynamic) :
+  examples = [[] for _ in transitionSets]
   with torch.no_grad() :
     EOS = Transition("EOS")
     config.moveWordIndex(0)
+    config.state = 0
     moved = True
     while moved :
+      ts = transitionSets[config.state]
       missingLinks = getMissingLinks(config)
       candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config) and "BACK" not in trans.name])
       if len(candidates) == 0 :
@@ -51,6 +54,7 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
         config.printForDebug(sys.stderr)
         print(str([[c[0],str(c[1])] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
       if dynamic :
+        network.setState(config.state) #Score with the output layer matching the current state
         output = network(features.unsqueeze(0).to(getDevice()))
         scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index]] for index in range(len(ts))])[::-1]
         candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1]
@@ -59,7 +63,7 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
 
       goldIndex = [str(trans) for trans in ts].index(str(candidateOracle))
       example = torch.cat([torch.LongTensor([goldIndex]), features])
-      examples.append(example)
+      examples[config.state].append(example)
 
       moved = applyTransition(strat, config, candidate, None)
 
@@ -94,19 +98,28 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
 ################################################################################
 
 ################################################################################
-def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
+def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
-  dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
+  transitionNames = {}
+  for ts in transitionSets :
+    for t in ts :
+      transitionNames.setdefault(str(t), (len(transitionNames), 0)) #A name shared by several sets keeps its first index, so indexes stay unique
+  transitionNames[dicts.nullToken] = (len(transitionNames), 0)
+  dicts.addDict("HISTORY", transitionNames)
+
   dicts.save(modelDir+"/dicts.json")
-  network = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
-  examples = []
+  network = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
+  examples = [[] for _ in transitionSets]
   sentences = copy.deepcopy(sentencesOriginal)
   print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
   for config in sentences :
-    examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, False)
+    extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, False)
+    for e in range(len(examples)) :
+      examples[e] += extracted[e]
   print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
-  examples = torch.stack(examples)
+  for e in range(len(examples)) :
+    examples[e] = torch.stack(examples[e])
 
   print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
   optimizer = torch.optim.Adam(network.parameters(), lr=lr)
@@ -119,9 +132,12 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
       sentences = copy.deepcopy(sentencesOriginal)
       print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr)
       for config in sentences :
-        examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, True)
+        extracted = extractExamples(debug, transitionSets, strategy, config, dicts, network, True)
+        for e in range(len(examples)) :
+          examples[e] += extracted[e]
       print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
-      examples = torch.stack(examples)
+      for e in range(len(examples)) :
+        examples[e] = torch.stack(examples[e])
 
     network.train()
     examples = examples.index_select(0, torch.randperm(examples.size(0)))
@@ -145,20 +161,25 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
       optimizer.step()
       totalLoss += float(loss)
 
-    bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc, predicted)
+    bestLoss, bestScore = evalModelAndSave(debug, network, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc, predicted)
 ################################################################################
 
 ################################################################################
-def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
+def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
 
   memory = None
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
-  dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
+  transitionNames = {}
+  for ts in transitionSets :
+    for t in ts :
+      transitionNames.setdefault(str(t), (len(transitionNames), 0)) #A name shared by several sets keeps its first index, so indexes stay unique
+  transitionNames[dicts.nullToken] = (len(transitionNames), 0)
+  dicts.addDict("HISTORY", transitionNames)
   dicts.save(modelDir + "/dicts.json")
 
-  policy_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
-  target_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
+  policy_net = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
+  target_net = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
   target_net.load_state_dict(policy_net.state_dict())
   target_net.eval()
   policy_net.train()
@@ -193,6 +214,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
 
       while True :
         missingLinks = getMissingLinks(sentence)
+        transitionSet = transitionSets[sentence.state]
         if debug :
           sentence.printForDebug(sys.stderr)
         action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle)
@@ -208,7 +230,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
         reward_ = rewarding(appliable, sentence, action, missingLinks, rewardFunc)
         reward = torch.FloatTensor([reward_]).to(getDevice())
 
-        #newState = None
+        newState = None
         if appliable :
           applyTransition(strategy, sentence, action, reward_)
           newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
@@ -233,6 +255,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
       if i >= nbExByEpoch :
         break
       sentIndex += 1
-    bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted)
+    bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted)
 ################################################################################
 
diff --git a/Transition.py b/Transition.py
index 5614283c16d1e6bd0cc04c54f48cb8f99988c7a9..2c6eadf5941ed449956b564dabbcabfbe5e44e06 100644
--- a/Transition.py
+++ b/Transition.py
@@ -284,12 +284,14 @@ def applyTag(config, colName, tag) :
 
 ################################################################################
 def applyTransition(strat, config, transition, reward) :
-  movement = strat[transition.name] if transition.name in strat else 0
+  movement = strat[transition.name][0] if transition.name in strat else 0
+  newState = strat[transition.name][1] if transition.name in strat else config.state #Transitions without a strategy entry leave the state unchanged
   transition.apply(config, strat)
   moved = config.moveWordIndex(movement)
   movement = movement if moved else 0
   if len(config.historyPop) > 0 and "BACK" not in transition.name :
     config.historyPop[-1] = (config.historyPop[-1][0], config.historyPop[-1][1], movement, reward)
+  config.state = newState
   return moved
 ################################################################################
 
diff --git a/main.py b/main.py
index 705d375f72e61adaae902cb388905891e28a74ce..c63413a0800d2a9a27ef71cabc8b2d16aab35351 100755
--- a/main.py
+++ b/main.py
@@ -15,8 +15,9 @@ from Transition import Transition
 from Util import isEmpty
 
 ################################################################################
-def printTS(ts, output) :
-  print("Transition Set :", [" ".join(map(str,[e for e in [trans.name, trans.size, trans.colName, trans.argument] if e is not None])) for trans in transitionSet], file=output)
+def printTS(transitionSet, output) : #transitionSet is a list of transition sets : one "Transition Set :" line is printed per set
+  for ts in transitionSet :
+    print("Transition Set :", [" ".join(map(str,[e for e in [trans.name, trans.size, trans.colName, trans.argument] if e is not None])) for trans in ts], file=output)
 ################################################################################
 
 ################################################################################
@@ -78,38 +79,42 @@ if __name__ == "__main__" :
     args.bootstrap = int(args.bootstrap)
 
   if args.transitions == "eager" :
-    transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]
+    transitionSets = [[Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]]
     args.predicted = "HEAD"
+    args.states = ["parser"]
+    strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}
   elif args.transitions == "tagparser" :
     tmpDicts = Dicts()
     tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
     tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
-    transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+tagActions+args.ts.split(',')) if len(elem) > 0]
+    transitionSets = [[Transition(elem) for elem in (tagActions+args.ts.split(',')) if len(elem) > 0], [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]]
     args.predictedStr = "HEAD,UPOS"
+    args.states = ["tagger", "parser"]
+    strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,1), "REDUCE" : (0,1), "TAG" : (0,1)}
   elif args.transitions == "swift" :
-    transitionSet = [Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]
+    transitionSets = [[Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]]
     args.predictedStr = "HEAD"
+    args.states = ["parser"]
+    strategy = {"RIGHT" : (1,0), "SHIFT" : (1,0), "LEFT" : (0,0), "REDUCE" : (0,0)}
   else :
     raise Exception("Unknown transition set '%s'"%args.transitions)
 
-  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0, "TAG" : 0}
-
   if args.mode == "train" :
     args.predicted = set({colName for colName in args.predictedStr.split(',')})
-    json.dump([args.predictedStr, [str(t) for t in transitionSet]], open(args.model+"/transitions.json", "w"))
+    json.dump([args.predictedStr, [[str(t) for t in transitionSet] for transitionSet in transitionSets]], open(args.model+"/transitions.json", "w"))
     json.dump(strategy, open(args.model+"/strategy.json", "w"))
-    printTS(transitionSet, sys.stderr)
+    printTS(transitionSets, sys.stderr)
     probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
-    Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
+    Train.trainMode(args.debug, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
   elif args.mode == "decode" :
     transInfos = json.load(open(args.model+"/transitions.json", "r"))
     transNames = json.load(open(args.model+"/transitions.json", "r"))[1]
     args.predictedStr = transInfos[0]
     args.predicted = set({colName for colName in args.predictedStr.split(',')})
-    transitionSet = [Transition(elem) for elem in transNames]
+    transitionSets = [[Transition(elem) for elem in ts] for ts in transNames]
     strategy = json.load(open(args.model+"/strategy.json", "r"))
-    printTS(transitionSet, sys.stderr)
-    Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.reward, args.predicted, args.model)
+    printTS(transitionSets, sys.stderr)
+    Decode.decodeMode(args.debug, args.corpus, args.type, transitionSets, strategy, args.reward, args.predicted, args.model)
   else :
     print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
     exit(1)