diff --git a/Config.py b/Config.py
index 3862951c569f35bc789151869ba9eac2d2ac72de..8b1e84b5bce905033d9b62b543784a3b3fae0b69 100644
--- a/Config.py
+++ b/Config.py
@@ -118,11 +118,11 @@ class Config :
       toPrint = []
       for colIndex in range(len(self.lines[index])) :
         value = str(self.getAsFeature(index, self.index2col[colIndex]))
-        if value == "" :
+        if value == "" or value == '_':
           value = "_"
-        elif self.index2col[colIndex] == "HEAD" and value != "-1":
+        elif self.index2col[colIndex] == "HEAD" and (value != "-1" and self.getAsFeature(index, "DEPREL") != 'root'):
           value = self.getAsFeature(int(value), "ID")
-        elif self.index2col[colIndex] == "HEAD" and value == "-1":
+        elif self.index2col[colIndex] == "HEAD" and (value == "-1" or self.getAsFeature(index, "DEPREL") == 'root'):
           value = "0"
         toPrint.append(value)
       print("\t".join(toPrint), file=output)
diff --git a/Decode.py b/Decode.py
index b0fcfb2378e000c464623386ea8ebdaef4fbd394..dfdb2ed8cd7be072368247891a1c6cbe90c262e2 100644
--- a/Decode.py
+++ b/Decode.py
@@ -77,7 +77,8 @@ def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) :
       reward = rewarding(True, config, candidate, missingLinks, rewardFunc)
       moved = applyTransition(strat, config, candidate, reward)
 
-  EOS.apply(config, strat)
+  if len(strat) > 1 :
+    EOS.apply(config, strat)
 
   network.to(currentDevice)
 ################################################################################
diff --git a/Networks.py b/Networks.py
index 8dfd4b6edb7e9422f06a87bb336e54fc2c61522f..5a035467d0142ad66466fa2e8e7816f686038da9 100644
--- a/Networks.py
+++ b/Networks.py
@@ -3,6 +3,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 import Features
 
+def get_network(mlp, dicts, outputSize, incremntal):
+  if mlp == 'POSTagNet':
+    return POSTagNet(dicts, outputSize, incremntal)
+  elif mlp == 'BaseNet':
+    return BaseNet(dicts, outputSize, incremntal)
+
 ################################################################################
 class BaseNet(nn.Module):
   def __init__(self, dicts, outputSize, incremental) :
@@ -134,3 +140,67 @@ class LSTMNet(nn.Module):
     return torch.cat([colsValuesBase, colsValuesLSTM, historyValues])
 ################################################################################
 
+################################################################################
+class POSTagNet(nn.Module):
+  def __init__(self, dicts, outputSize, incremental) :
+    super().__init__()
+    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
+
+    self.incremental = incremental
+    self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2"
+    self.historyNb = 5
+    self.suffixSize = 4
+    self.prefixSize = 4
+    self.columns = ["UPOS", "FORM"]
+
+    self.embSize = 64
+    self.nbTargets = len(self.featureFunction.split())
+    self.inputSize = len(self.columns)*self.nbTargets+self.historyNb+self.suffixSize+self.prefixSize
+    self.outputSize = outputSize
+    for name in dicts.dicts :
+      self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
+    self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
+    self.fc2 = nn.Linear(1600, outputSize)
+    self.dropout = nn.Dropout(0.3)
+
+    self.apply(self.initWeights)
+
+  def forward(self, x) :
+    embeddings = []
+    for i in range(len(self.columns)) :
+      embeddings.append(getattr(self, "emb_"+self.columns[i])(x[...,i*self.nbTargets:(i+1)*self.nbTargets]))
+    y = torch.cat(embeddings,-1).view(x.size(0),-1)
+    curIndex = len(self.columns)*self.nbTargets
+    if self.historyNb > 0 :
+      historyEmb = getattr(self, "emb_HISTORY")(x[...,curIndex:curIndex+self.historyNb]).view(x.size(0),-1)
+      y = torch.cat([y, historyEmb],-1)
+      curIndex = curIndex+self.historyNb
+    if self.prefixSize > 0 :
+      prefixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.prefixSize]).view(x.size(0),-1)
+      y = torch.cat([y, prefixEmb],-1)
+      curIndex = curIndex+self.prefixSize
+    if self.suffixSize > 0 :
+      suffixEmb = getattr(self, "emb_LETTER")(x[...,curIndex:curIndex+self.suffixSize]).view(x.size(0),-1)
+      y = torch.cat([y, suffixEmb],-1)
+      curIndex = curIndex+self.suffixSize
+    y = self.dropout(y)
+    y = F.relu(self.dropout(self.fc1(y)))
+    y = self.fc2(y)
+    return y
+
+  def currentDevice(self) :
+    return self.dummyParam.device
+
+  def initWeights(self,m) :
+    if type(m) == nn.Linear:
+      torch.nn.init.xavier_uniform_(m.weight)
+      m.bias.data.fill_(0.01)
+
+  def extractFeatures(self, dicts, config) :
+    colsValues = Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns, self.incremental)
+    historyValues = Features.extractHistoryFeatures(dicts, config, self.historyNb)
+    prefixValues = Features.extractPrefixFeatures(dicts, config, self.prefixSize)
+    suffixValues = Features.extractSuffixFeatures(dicts, config, self.suffixSize)
+    return torch.cat([colsValues, historyValues, prefixValues, suffixValues])
+
+################################################################################
\ No newline at end of file
diff --git a/Train.py b/Train.py
index 8b948d22ab8ce610d3b0535e8f4b4e2ed666c7fe..f0f0c5cd92e40b13280cb2d3a3a9bf6d6620104d 100644
--- a/Train.py
+++ b/Train.py
@@ -16,15 +16,15 @@ import Config
 from conll18_ud_eval import load_conllu, evaluate
 
 ################################################################################
-def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
+def trainMode(debug, filename, type, transitionSet, strategy, mlp, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
   sentences = Config.readConllu(filename, predicted)
 
   if type == "oracle" :
-    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent)
+    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, mlp, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent)
     return
 
   if type == "rl":
-    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent)
+    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, mlp, sentences, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent)
     return
 
   print("ERROR : unknown type '%s'"%type, file=sys.stderr)
@@ -63,7 +63,8 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
 
       moved = applyTransition(strat, config, candidate, None)
 
-    EOS.apply(config, strat)
+    if len(strat) > 1:
+      EOS.apply(config, strat)
 
   return examples
 ################################################################################
@@ -94,12 +95,12 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
 ################################################################################
 
 ################################################################################
-def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
+def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, mlp, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM","UPOS","LETTER"], 2)
   dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
   dicts.save(modelDir+"/dicts.json")
-  network = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
+  network = Networks.get_network(mlp, dicts, len(transitionSet), incremental).to(getDevice())
   examples = []
   sentences = copy.deepcopy(sentencesOriginal)
   print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
@@ -149,7 +150,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
 ################################################################################
 
 ################################################################################
-def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
+def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, mlp, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, countBreak, predicted, silent=False) :
 
   memory = None
   dicts = Dicts()
@@ -157,8 +158,8 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
   dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
   dicts.save(modelDir + "/dicts.json")
 
-  policy_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
-  target_net = Networks.BaseNet(dicts, len(transitionSet), incremental).to(getDevice())
+  policy_net = Networks.get_network(mlp, dicts, len(transitionSet), incremental).to(getDevice())
+  target_net = Networks.get_network(mlp, dicts, len(transitionSet), incremental).to(getDevice())
   target_net.load_state_dict(policy_net.state_dict())
   target_net.eval()
   policy_net.train()
diff --git a/main.py b/main.py
index c82affe6e8b6ab97d6643b85c4a44d2fc7e9e086..2308dcc93237973f699cbeab3431f44c08eea514 100755
--- a/main.py
+++ b/main.py
@@ -51,7 +51,7 @@ if __name__ == "__main__" :
   parser.add_argument("--silent", "-s", default=False, action="store_true",
     help="Don't print advancement infos.")
   parser.add_argument("--transitions", default="eager",
-    help="Transition set to use (eager | swift | tagparser).")
+    help="Transition set to use (eager | swift | tagparser | tag).")
   parser.add_argument("--ts", default="",
     help="Comma separated list of supplementary transitions. Example \"BACK 1,BACK 2\"")
   parser.add_argument("--reward", default="A",
@@ -86,13 +86,24 @@ if __name__ == "__main__" :
     tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
     transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+tagActions+args.ts.split(',')) if len(elem) > 0]
     args.predicted = "HEAD,UPOS"
+  elif args.transitions == "tag":
+    tmpDicts = Dicts()
+    tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
+    tagActions = ["TAG UPOS %s" % p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
+    transitionSet = [Transition(elem) for elem in (tagActions + args.ts.split(',')) if len(elem) > 0]
+    args.predicted = "UPOS"
   elif args.transitions == "swift" :
     transitionSet = [Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]
     args.predicted = "HEAD"
   else :
     raise Exception("Unknown transition set '%s'"%args.transitions)
 
-  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0, "TAG" : 0}
+  if args.transitions == "tag":
+    strategy = {"TAG": 1}
+    mlp = 'POSTagNet'
+  else:
+    strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0, "TAG" : 0}
+    mlp = 'BaseNet'
 
   args.predicted = set({colName for colName in args.predicted.split(',')})
 
@@ -101,7 +112,7 @@ if __name__ == "__main__" :
     json.dump(strategy, open(args.model+"/strategy.json", "w"))
     printTS(transitionSet, sys.stderr)
     probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
-    Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
+    Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, mlp, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, int(args.countBreak), args.predicted, args.silent)
   elif args.mode == "decode" :
     transNames = json.load(open(args.model+"/transitions.json", "r"))
     transitionSet = [Transition(elem) for elem in transNames]