diff --git a/Config.py b/Config.py
index a5278a5f8b000165d672f020ad6162638bdca078..3862951c569f35bc789151869ba9eac2d2ac72de 100644
--- a/Config.py
+++ b/Config.py
@@ -3,13 +3,13 @@ import sys
 
 ################################################################################
 class Config :
-  def __init__(self, col2index, index2col) :
+  def __init__(self, col2index, index2col, predicted) :
     self.lines = []
     self.goldChilds = []
     self.predChilds = []
     self.col2index = col2index
     self.index2col = index2col
-    self.predicted = set({"HEAD", "DEPREL"})
+    self.predicted = predicted
     self.wordIndex = 0
     self.maxWordIndex = 0 #To keep a track of the max value, in case of backtrack
     self.stack = []
@@ -130,7 +130,7 @@ class Config :
 ################################################################################
   
 ################################################################################
-def readConllu(filename) :
+def readConllu(filename, predicted) :
   configs = []
   defaultMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC"
   col2index, index2col = readMCD(defaultMCD)
@@ -156,7 +156,7 @@ def readConllu(filename) :
 
       configs[-1].comments = comments
 
-      configs.append(Config(col2index, index2col))
+      configs.append(Config(col2index, index2col, predicted))
       currentIndex = 0
       id2index = {}
       comments = []
@@ -167,7 +167,7 @@ def readConllu(filename) :
       continue
 
     if len(configs) == 0 :
-      configs.append(Config(col2index, index2col))
+      configs.append(Config(col2index, index2col, predicted))
       currentIndex = 0
       id2index = {}
 
diff --git a/Decode.py b/Decode.py
index 9bdda8456c919495e496863a877f2b993efd5e22..b0fcfb2378e000c464623386ea8ebdaef4fbd394 100644
--- a/Decode.py
+++ b/Decode.py
@@ -83,9 +83,9 @@ def decodeModel(ts, strat, config, network, dicts, debug, rewardFunc) :
 ################################################################################
 
 ################################################################################
-def decodeMode(debug, filename, type, transitionSet, strategy, rewardFunc, modelDir=None, network=None, dicts=None, output=sys.stdout) :
+def decodeMode(debug, filename, type, transitionSet, strategy, rewardFunc, predicted, modelDir=None, network=None, dicts=None, output=sys.stdout) :
 
-  sentences = Config.readConllu(filename)
+  sentences = Config.readConllu(filename, predicted)
 
   if type in ["random", "oracle"] :
     decodeFunc = oracleDecode if type == "oracle" else randomDecode
diff --git a/Dicts.py b/Dicts.py
index f6787ec1ab945ed1066c04f0790960e2b8f6bab2..d7e0c9d6b7a2b4977d6ee4f757036b1e9c9ec865 100644
--- a/Dicts.py
+++ b/Dicts.py
@@ -63,6 +63,11 @@ class Dicts :
       return self.dicts[col][value.lower()][0]
     return self.dicts[col][self.unkToken][0]
 
+  def getElementsOf(self, col) :
+    if col not in self.dicts :
+      raise Exception("Unknown dict name %s"%col)
+    return self.dicts[col].keys()
+
   def save(self, target) :
     json.dump(self.dicts, open(target, "w"))
 
diff --git a/Train.py b/Train.py
index 34ab29333f56d4cc1be880a5d1bfe97cebeea502..c34f6ad1d82ef7924b3ad6182b78d4df1927cce5 100644
--- a/Train.py
+++ b/Train.py
@@ -16,15 +16,15 @@ import Config
 from conll18_ud_eval import load_conllu, evaluate
 
 ################################################################################
-def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, silent=False) :
-  sentences = Config.readConllu(filename)
+def trainMode(debug, filename, type, transitionSet, strategy, modelDir, nbIter, batchSize, devFile, bootstrapInterval, incremental, rewardFunc, lr, gamma, probas, predicted, silent=False) :
+  sentences = Config.readConllu(filename, predicted)
 
   if type == "oracle" :
-    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, silent)
+    trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent)
     return
 
   if type == "rl":
-    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, silent)
+    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, incremental, rewardFunc, lr, gamma, probas, predicted, silent)
     return
 
   print("ERROR : unknown type '%s'"%type, file=sys.stderr)
@@ -69,19 +69,21 @@ def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
 ################################################################################
 
 ################################################################################
-def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc) :
+def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted) :
+  col2metric = {"HEAD" : "UAS", "DEPREL" : "LAS", "UPOS" : "UPOS", "FEATS" : "UFeats"}
+
   devScore = ""
   saved = True if bestLoss is None else totalLoss < bestLoss
   bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
   if devFile is not None :
     outFilename = modelDir+"/predicted_dev.conllu"
-    Decode.decodeMode(debug, devFile, "model", ts, strat, rewardFunc, modelDir, model, dicts, open(outFilename, "w"))
+    Decode.decodeMode(debug, devFile, "model", ts, strat, rewardFunc, predicted, modelDir, model, dicts, open(outFilename, "w"))
     res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
-    UAS = res["UAS"][0].f1
-    score = UAS
+    scores = [res[col2metric[col]][0].f1 for col in predicted]
+    score = sum(scores)/len(scores)
     saved = True if bestScore is None else score > bestScore
     bestScore = score if bestScore is None else max(bestScore, score)
-    devScore = ", Dev : UAS=%.2f"%(UAS)
+    devScore = ", Dev : "+" ".join(["%s=%.2f"%(col2metric[list(predicted)[i]], scores[i]) for i in range(len(predicted))])
   if saved :
     torch.save(model, modelDir+"/network.pt")
   for out in [sys.stderr, open(modelDir+"/train.log", "w" if epoch == 1 else "a")] :
@@ -91,7 +93,7 @@ def evalModelAndSave(debug, model, ts, strat, dicts, modelDir, devFile, bestLoss
 ################################################################################
 
 ################################################################################
-def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, silent=False) :
+def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, incremental, rewardFunc, lr, predicted, silent=False) :
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM","UPOS"], 2)
   dicts.addDict("HISTORY", {**{str(t) : (transitionSet.index(t),0) for t in transitionSet}, **{dicts.nullToken : (len(transitionSet),0)}})
@@ -142,11 +144,11 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
       optimizer.step()
       totalLoss += float(loss)
 
-    bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc)
+    bestLoss, bestScore = evalModelAndSave(debug, network, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbEpochs, incremental, rewardFunc, predicted)
 ################################################################################
 
 ################################################################################
-def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, silent=False) :
+def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, incremental, rewardFunc, lr, gamma, probas, predicted, silent=False) :
 
   memory = None
   dicts = Dicts()
@@ -225,6 +227,6 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
       if i >= nbExByEpoch :
         break
       sentIndex += 1
-    bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc)
+    bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSet, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted)
 ################################################################################
 
diff --git a/Transition.py b/Transition.py
index 92f9e9d6e7ef649daeda64b983e536942ea67620..5614283c16d1e6bd0cc04c54f48cb8f99988c7a9 100644
--- a/Transition.py
+++ b/Transition.py
@@ -8,14 +8,17 @@ class Transition :
   def __init__(self, name) :
     splited = name.split()
     self.name = splited[0]
-    self.size = (1 if self.name in ["LEFT","RIGHT"] else None) if len(splited) == 1 else int(splited[1])
-    if not self.name in ["SHIFT","REDUCE","LEFT","RIGHT","BACK","EOS"] :
+    self.size = (1 if self.name in ["LEFT","RIGHT"] else None) if (len(splited) == 1 or splited[0] == "TAG") else int(splited[1])
+    self.colName = None
+    self.argument = None
+    if len(splited) == 3 :
+      self.colName = splited[1]
+      self.argument = splited[2]
+    if not self.name in ["SHIFT","REDUCE","LEFT","RIGHT","BACK","EOS","TAG"] :
       raise(Exception("'%s' is not a valid transition type."%name))
 
   def __str__(self) :
-    if self.size is None :
-      return self.name
-    return "%s %d"%(self.name, self.size)
+    return " ".join(map(str,[e for e in [self.name, self.size, self.colName, self.argument] if e is not None]))
 
   def __lt__(self, other) :
     return str(self) < str(other)
@@ -33,6 +36,8 @@ class Transition :
       data = applyReduce(config)
     elif self.name == "EOS" :
       applyEOS(config)
+    elif self.name == "TAG" :
+      applyTag(config, self.colName, self.argument)
     elif "BACK" in self.name :
       config.historyHistory.add(str([t[0].name for t in config.historyPop]))
       applyBack(config, strategy, self.size)
@@ -45,21 +50,32 @@ class Transition :
 
   def appliable(self, config) :
     if self.name == "RIGHT" :
+      for colName in config.predicted :
+        if colName not in ["HEAD","DEPREL"] and isEmpty(config.getAsFeature(config.wordIndex, colName)) :
+          return False
       if not (len(config.stack) >= self.size and isEmpty(config.getAsFeature(config.wordIndex, "HEAD")) and not linkCauseCycle(config, config.stack[-self.size], config.wordIndex)) :
         return False
       orphansInStack = [s for s in config.stack[-self.size+1:] if isEmpty(config.getAsFeature(s, "HEAD"))] if self.size > 1 else []
       return len(orphansInStack) == 0
     if self.name == "LEFT" :
+      for colName in config.predicted :
+        if colName not in ["HEAD","DEPREL"] and isEmpty(config.getAsFeature(config.wordIndex, colName)) :
+          return False
       if not (len(config.stack) >= self.size and isEmpty(config.getAsFeature(config.stack[-self.size], "HEAD")) and not linkCauseCycle(config, config.wordIndex, config.stack[-self.size])) :
         return False
       orphansInStack = [s for s in config.stack[-self.size+1:] if isEmpty(config.getAsFeature(s, "HEAD"))] if self.size > 1 else []
       return len(orphansInStack) == 0
     if self.name == "SHIFT" :
+      for colName in config.predicted :
+        if colName not in ["HEAD","DEPREL"] and isEmpty(config.getAsFeature(config.wordIndex, colName)) :
+          return False
       return config.wordIndex < len(config.lines) - 1
     if self.name == "REDUCE" :
       return len(config.stack) > 0 and not isEmpty(config.getAsFeature(config.stack[-1], "HEAD"))
     if self.name == "EOS" :
       return config.wordIndex == len(config.lines) - 1
+    if self.name == "TAG" :
+      return isEmpty(config.getAsFeature(config.wordIndex, self.colName))
     if "BACK" in self.name :
       if len(config.historyPop) < self.size :
         return False
@@ -77,6 +93,8 @@ class Transition :
       return scoreOracleShift(config, missingLinks)
     if self.name == "REDUCE" :
       return scoreOracleReduce(config, missingLinks)
+    if self.name == "TAG" :
+      return 0 if self.argument == config.getGold(config.wordIndex, self.colName) else 1
     if "BACK" in self.name :
       return 1
 
@@ -165,6 +183,8 @@ def applyBack(config, strategy, size) :
       applyBackShift(config)
     elif trans.name == "REDUCE" :
       applyBackReduce(config, data)
+    elif trans.name == "TAG" :
+      applyBackTag(config, trans.colName)
     else :
       print("ERROR : trying to apply BACK to '%s'"%trans.name, file=sys.stderr)
       exit(1)
@@ -198,6 +218,11 @@ def applyBackReduce(config, data) :
   config.stack.append(data)
 ################################################################################
 
+################################################################################
+def applyBackTag(config, colName) :
+  config.set(config.wordIndex, colName, "")
+################################################################################
+
 ################################################################################
 def applyRight(config, size=1) :
   config.set(config.wordIndex, "HEAD", config.stack[-size])
@@ -252,6 +277,11 @@ def applyEOS(config) :
     config.predChilds[rootIndex].append(index)
 ################################################################################
 
+################################################################################
+def applyTag(config, colName, tag) :
+  config.set(config.wordIndex, colName, tag)
+################################################################################
+
 ################################################################################
 def applyTransition(strat, config, transition, reward) :
   movement = strat[transition.name] if transition.name in strat else 0
diff --git a/main.py b/main.py
index 89c4e783e973f924ab84de5e6a265c4ee6865fdf..874ce0f58f2e43d5d2f6faa1294f88c1105614e3 100755
--- a/main.py
+++ b/main.py
@@ -10,12 +10,13 @@ import json
 import Util
 import Train
 import Decode
+from Dicts import Dicts
 from Transition import Transition
-
+from Util import isEmpty
 
 ################################################################################
 def printTS(ts, output) :
-  print("Transition Set :", [trans.name + ("" if trans.size is None else " "+str(trans.size)) for trans in transitionSet], file=output)
+  print("Transition Set :", [" ".join(map(str,[e for e in [trans.name, trans.size, trans.colName, trans.argument] if e is not None])) for trans in transitionSet], file=output)
 ################################################################################
 
 ################################################################################
@@ -76,25 +77,35 @@ if __name__ == "__main__" :
 
   if args.transitions == "eager" :
     transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+args.ts.split(',')) if len(elem) > 0]
+    args.predicted = "HEAD"
+  elif args.transitions == "tagparser" :
+    tmpDicts = Dicts()
+    tmpDicts.readConllu(args.corpus, ["UPOS"], 0)
+    tagActions = ["TAG UPOS %s"%p for p in tmpDicts.getElementsOf("UPOS") if "__" not in p and not isEmpty(p)]
+    transitionSet = [Transition(elem) for elem in (["SHIFT","REDUCE","LEFT","RIGHT"]+tagActions+args.ts.split(',')) if len(elem) > 0]
+    args.predicted = "HEAD,UPOS"
   elif args.transitions == "swift" :
     transitionSet = [Transition(elem) for elem in (["SHIFT"]+["LEFT "+str(n) for n in range(1,6)]+["RIGHT "+str(n) for n in range(1,6)]+args.ts.split(',')) if len(elem) > 0]
+    args.predicted = "HEAD"
   else :
     raise Exception("Unknown transition set '%s'"%args.transitions)
 
-  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0}
+  strategy = {"RIGHT" : 1, "SHIFT" : 1, "LEFT" : 0, "REDUCE" : 0, "TAG" : 0}
+
+  args.predicted = set({colName for colName in args.predicted.split(',')})
 
   if args.mode == "train" :
     json.dump([str(t) for t in transitionSet], open(args.model+"/transitions.json", "w"))
     json.dump(strategy, open(args.model+"/strategy.json", "w"))
     printTS(transitionSet, sys.stderr)
     probas = [list(map(float, args.probaRandom.split(','))), list(map(float, args.probaOracle.split(',')))]
-    Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, args.silent)
+    Train.trainMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.incr, args.reward, float(args.lr), float(args.gamma), probas, args.predicted, args.silent)
   elif args.mode == "decode" :
     transNames = json.load(open(args.model+"/transitions.json", "r"))
     transitionSet = [Transition(elem) for elem in transNames]
     strategy = json.load(open(args.model+"/strategy.json", "r"))
     printTS(transitionSet, sys.stderr)
-    Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.reward, args.model)
+    Decode.decodeMode(args.debug, args.corpus, args.type, transitionSet, strategy, args.reward, args.predicted, args.model)
   else :
     print("ERROR : unknown mode '%s'"%args.mode, file=sys.stderr)
     exit(1)