diff --git a/Networks.py b/Networks.py
index e894867c443497722a90e125d8989d00cb33170c..b632ccde3fe13263d5cd92d947fe074926ec6c73 100644
--- a/Networks.py
+++ b/Networks.py
@@ -3,6 +3,18 @@ import torch.nn as nn
 import torch.nn.functional as F
 import Features
 
+################################################################################
+def createNetwork(name, dicts, outputSizes, incremental) :
+  if name == "base" :
+    return BaseNet(dicts, outputSizes, incremental)
+  elif name == "lstm" :
+    return LSTMNet(dicts, outputSizes, incremental)
+  elif name == "separated" :
+    return SeparatedNet(dicts, outputSizes, incremental)
+
+  raise Exception("Unknown network name '%s'"%name)
+################################################################################
+
 ################################################################################
 class BaseNet(nn.Module):
   def __init__(self, dicts, outputSizes, incremental) :
@@ -73,6 +85,77 @@ class BaseNet(nn.Module):
 
 ################################################################################
 
+################################################################################
+class SeparatedNet(nn.Module):
+  def __init__(self, dicts, outputSizes, incremental) :
+    super().__init__()
+    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
+
+    self.incremental = incremental
+    self.state = 0
+    self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
+    self.historyNb = 5
+    self.suffixSize = 4
+    self.prefixSize = 4
+    self.columns = ["UPOS", "FORM"]
+
+    self.embSize = 64
+    self.nbTargets = len(self.featureFunction.split())
+    self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
+    self.outputSizes = outputSizes
+
+    for i in range(len(outputSizes)) :
+      for name in dicts.dicts :
+        self.add_module("emb_"+name+"_"+str(i), nn.Embedding(len(dicts.dicts[name]), self.embSize))
+      self.add_module("fc1_"+str(i), nn.Linear(self.inputSize * self.embSize, 1600))
+      self.add_module("output_"+str(i), nn.Linear(1600, outputSizes[i]))
+    self.dropout = nn.Dropout(0.3)
+
+    self.apply(self.initWeights)
+
+  def setState(self, state) :
+    self.state = state
+
+  def forward(self, x) :
+    embeddings = []
+    for i in range(len(self.columns)) :
+      embeddings.append(getattr(self, "emb_"+self.columns[i]+"_"+str(self.state))(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
+    y = torch.cat(embeddings,-1).view(x.size(0),-1)
+    curIndex = len(self.columns)*self.nbTargets
+    if self.historyNb > 0 :
+      historyEmb = getattr(self, "emb_HISTORY_"+str(self.state))(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
+      y = torch.cat([y, historyEmb],-1)
+      curIndex = curIndex+self.historyNb
+    if self.prefixSize > 0 :
+      prefixEmb = getattr(self, "emb_LETTER_"+str(self.state))(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
+      y = torch.cat([y, prefixEmb],-1)
+      curIndex = curIndex+self.prefixSize
+    if self.suffixSize > 0 :
+      suffixEmb = getattr(self, "emb_LETTER_"+str(self.state))(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
+      y = torch.cat([y, suffixEmb],-1)
+      curIndex = curIndex+self.suffixSize
+    y = self.dropout(y)
+    y = F.relu(self.dropout(getattr(self, "fc1_"+str(self.state))(y)))
+    y = getattr(self, "output_"+str(self.state))(y)
+    return y
+
+  def currentDevice(self) :
+    return self.dummyParam.device
+
+  def initWeights(self,m) :
+    if type(m) == nn.Linear:
+      torch.nn.init.xavier_uniform_(m.weight)
+      m.bias.data.fill_(0.01)
+
+  def extractFeatures(self, dicts, config) :
+    colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
+    historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
+    prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
+    suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
+    return torch.cat([colsValues, historyValues, prefixValues, suffixValues])
+
+################################################################################
+
 ################################################################################
 class LSTMNet(nn.Module):
   def __init__(self, dicts, outputSizes, incremental) :
@@ -91,7 +174,7 @@ class LSTMNet(nn.Module):
     self.nbInputBase = len(self.featureFunction.split())
     self.nbTargets = self.nbInputBase + self.nbInputLSTM
     self.inputSize = len(self.columns)*self.nbTargets+self.historyNb
-    self.outputSize = outputSize
+    self.outputSizes = outputSizes
     for name in dicts.dicts :
       self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
     self.lstmFeat = nn.LSTM(len(self.columns)*self.embSize, len(self.columns)*int(self.embSize/2), 1, batch_first=True, bidirectional = True)
diff --git a/Train.py b/Train.py
index c6b57bf4ab85bfebebf16b33a5e1a29ae1eab2ec..bedbfbf734e3ec752a7ad1fdf9b593934ca808d9 100644
--- a/Train.py
+++ b/Train.py
@@ -16,15 +16,15 @@ import Config
 from conll18_ud_eval import load_conllu, evaluate
 
 ################################################################################
-def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
+def trainMode(debug, networkName, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
   sentences = Config.readConllu(filename, predicted)
 
   if type == "oracle" :
-    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent)
+    trainModelOracle(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent)
     return
 
   if type == "rl":
-    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent)
+    trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent)
     return
 
   print("ERROR : unknown type '%s'"%type, file=sys.stderr)
@@ -98,7 +98,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
 ################################################################################
 
 ################################################################################
-def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
+def trainModelOracle(debug, networkName, modelDir, filename, nbEpochs, batchSize, devFile, transitionSets, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
   transitionNames = {}
@@ -109,7 +109,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
   dicts.addDict("HISTORY", transitionNames)
 
   dicts.save(modelDir+"/dicts.json")
-  network = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
+  network = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
   examples = [[] for _ in transitionSets]
   sentences = copy.deepcopy(sentencesOriginal)
   print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
@@ -185,7 +185,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
 ################################################################################
 
 ################################################################################
-def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
+def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devFile, transitionSets, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
 
   memory = None
   dicts = Dicts()
@@ -198,8 +198,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
   dicts.addDict("HISTORY", transitionNames)
   dicts.save(modelDir + "/dicts.json")
 
-  policy_net = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
-  target_net = Networks.BaseNet(dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
+  policy_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
+  target_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental).to(getDevice())
   target_net.load_state_dict(policy_net.state_dict())
   target_net.eval()
   policy_net.train()
diff --git a/main.py b/main.py
index c63413a0800d2a9a27ef71cabc8b2d16aab35351..4912d85b44ce1a5e8fb4c81489fc28b3a2de9401 100755
--- a/main.py
+++ b/main.py
@@ -55,6 +55,8 @@ if __name__ == "__main__" :
     help="Transition set to use (eager | swift | tagparser).")
   parser.add_argument("--ts", default="",
     help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
+  parser.add_argument("--network", default="base",
+    help="Name of the neural network to use (base | lstm | separated).")
   parser.add_argument("--reward", default="A",
     help="Reward function to use (A,B,C,D,E)")
   parser.add_argument("--probaRandom", default="0.6,4,0.1",
@@ -105,7 +107,7 @@ if __name__ == "__main__" :
     json.dump(strategy, open(args.model+"/strategy.json", "w"))
     printTS(transitionSets, sys.stderr)
     probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
-    Train.trainMode(args.debug, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
+    Train.trainMode(args.debug, args.network, args.corpus, args.type, transitionSets, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
   elif args.mode == "decode" :
     transInfos = json.load(open(args.model+"/transitions.json", "r"))
     transNames = json.load(open(args.model+"/transitions.json", "r"))[1]