diff --git a/Decode.py b/Decode.py
index c2e8df021b9c65536923cd5bc38fcf45d440ed10..f5ac8341017f35e7fcd2f95e0bde530bbd598fdd 100644
--- a/Decode.py
+++ b/Decode.py
@@ -1,7 +1,6 @@
 import random
 import sys
 from Transition import Transition, getMissingLinks, applyTransition
-from Features import extractFeatures
 from Dicts import Dicts
 from Util import getDevice
 import Config
@@ -56,7 +55,7 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
 
   with torch.no_grad():
     while moved :
-      features = extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
+      features = network.extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
       output = network(features)
       scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
       candidates = [[cand[0],cand[2]] for cand in scores if cand[1]]
diff --git a/Dicts.py b/Dicts.py
index 41da2a2c50b193abe59dd599498368494b1cdcdf..010dc27fc594fa117f488c13615c9374c0fdf7e3 100644
--- a/Dicts.py
+++ b/Dicts.py
@@ -9,6 +9,9 @@ class Dicts :
     self.nullToken = "__null__"
     self.noStackToken = "__nostack__"
     self.oobToken = "__oob__"
+    self.noDepLeft = "__nodepleft__"
+    self.noDepRight = "__nodepright__"
+    self.noGov = "__nogov__"
 
   def readConllu(self, filename, colsSet=None) :
     defaultMCD = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC"
@@ -30,7 +33,7 @@ class Dicts :
           targetColumns = list(col2index.keys())
         else :
           targetColumns = list(colsSet)
-        self.dicts = {col : {self.unkToken : 0, self.nullToken : 1} for col in targetColumns}
+        self.dicts = {col : {self.unkToken : 0, self.nullToken : 1, self.noStackToken : 2, self.oobToken : 3, self.noDepLeft : 4, self.noDepRight : 5, self.noGov : 6} for col in targetColumns}
 
       splited = line.split('\t')
       for col in targetColumns :
diff --git a/Features.py b/Features.py
index c152e6318ad28fc7b03bece5804a49c5a3fc7923..6f67c7574077e10d41a823265bc0019a49837d91 100644
--- a/Features.py
+++ b/Features.py
@@ -3,61 +3,79 @@ import sys
 from Util import isEmpty
 
 ################################################################################
-def extractFeatures(dicts, config) :
-  return extractFeaturesPosExtended(dicts, config)
-################################################################################
-
-################################################################################
-def extractFeaturesPos(dicts, config) :
-  bufferWindow = range(-2,2+1)
-  stackWindow = range(0,3+1)
-  totalSize = len(bufferWindow)+len(stackWindow)
-
-  result = torch.zeros(totalSize, dtype=torch.int)
+# Input : b=buffer s=stack .0=governor .x=rightChild#x+1 .-x=leftChild#-x-1
+# Output : list of sentence indexes pointing to elements of featureFunction
+# Special output values :
+# -1 : Out of bounds
+# -2 : Not in stack
+# -3 : No dependent left
+# -4 : No dependent right
+# -5 : No gov
+def extractIndexes(config, featureFunction) :
+  features = featureFunction.split()
+  res = []
+  for feature in features :
+    splited = feature.split('.')
+    obj = splited[0]
+    index = int(splited[1])
+    if obj == "b" :
+      index = config.wordIndex + index
+      if index not in (range(len(config.lines))) :
+        index = -1
+    elif obj == "s" :
+      if index not in range(len(config.stack)) :
+        index = -2
+      else :
+        index = config.stack[-1-index]
+    for depIndex in map(int,splited[2:]) :
+      if index < 0 :
+        break
+      if depIndex == 0 :
+        head = config.getAsFeature(index, "HEAD")
+        if isEmpty(head) :
+          index = -5
+        else :
+          index = int(head)
+        continue
+      if depIndex > 0 :
+        rightChilds = [child for child in config.predChilds[index] if child > index]
+        if depIndex-1 in range(len(rightChilds)) :
+          index = rightChilds[depIndex-1]
+        else :
+          index = -4
+      else :
+        leftChilds = [child for child in config.predChilds[index] if child < index]
+        if abs(depIndex)-1 in range(len(leftChilds)) :
+          index = leftChilds[abs(depIndex)-1]
+        else :
+          index = -3
+    res.append(index)
 
-  insertIndex = 0
-  for i in bufferWindow :
-    index = config.wordIndex + i
-    bufferPos = dicts.oobToken if index not in range(len(config.lines)) else config.getAsFeature(index, "UPOS")
-    result[insertIndex] = dicts.get("UPOS", bufferPos)
-    insertIndex += 1
-
-  for i in stackWindow :
-    stackPos = dicts.noStackToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "UPOS")
-    result[insertIndex] = dicts.get("UPOS", stackPos)
-    insertIndex += 1
-
-  return result
+  return res
 ################################################################################
 
 ################################################################################
-# For each stack element, add its POS and the POS of its governor
-def extractFeaturesPosExtended(dicts, config) :
-  bufferWindow = range(-2,2+1)
-  stackWindow = range(0,3+1)
-  totalSize = len(bufferWindow)+2*len(stackWindow)
+# For each element of the feature function and for each column, concatenante the dict index
+def extractColsFeatures(dicts, config, featureFunction, cols) :
+  specialValues = {-1 : dicts.oobToken, -2 : dicts.noStackToken, -3 : dicts.noDepLeft, -4 : dicts.noDepRight, -5 : dicts.noGov}
 
+  indexes = extractIndexes(config, featureFunction)
+  totalSize = len(cols)*len(indexes)
   result = torch.zeros(totalSize, dtype=torch.int)
 
   insertIndex = 0
-  for i in bufferWindow :
-    index = config.wordIndex + i
-    bufferPos = dicts.oobToken if index not in range(len(config.lines)) else config.getAsFeature(index, "UPOS")
-    result[insertIndex] = dicts.get("UPOS", bufferPos)
-    insertIndex += 1
-
-  for i in stackWindow :
-    stackPos = dicts.noStackToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "UPOS")
-    stackGovHead = dicts.nullToken if i not in range(len(config.stack)) else config.getAsFeature(config.stack[-1-i], "HEAD")
-    stackGovPos = dicts.nullToken
-    if not isEmpty(stackGovHead) and stackGovHead != dicts.nullToken :
-      stackGovPos = config.getAsFeature(int(stackGovHead), "UPOS")
-    elif stackGovHead == dicts.nullToken :
-      stackGovPos = dicts.noStackToken
-    result[insertIndex] = dicts.get("UPOS", stackPos)
-    insertIndex += 1
-    result[insertIndex] = dicts.get("UPOS", stackGovPos)
-    insertIndex += 1
+  for index in indexes :
+    if index < 0 :
+      for col in cols :
+        result[insertIndex] = dicts.get(col, specialValues[index])
+        insertIndex += 1
+    else :
+      for col in cols :
+        value = config.getAsFeature(index, col)
+        if isEmpty(value) :
+          value = dicts.nullToken
+        result[insertIndex] = dicts.get(col, value)
+        insertIndex += 1
 
   return result
 ################################################################################
diff --git a/Networks.py b/Networks.py
index d1beadec23625290e556e5d3329378d868b60dd3..6f91141de143b073b1b1310ea8467cfa20ed1ab6 100644
--- a/Networks.py
+++ b/Networks.py
@@ -1,19 +1,23 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import Features
 
 ################################################################################
 class BaseNet(nn.Module):
-  def __init__(self, dicts, inputSize, outputSize) :
+  def __init__(self, dicts, outputSize) :
     super().__init__()
     self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
 
+    self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
+    self.columns = ["UPOS"]
+
     self.embSize = 64
-    self.inputSize = inputSize
+    self.inputSize = len(self.columns)*len(self.featureFunction.split())
     self.outputSize = outputSize
     for name in dicts.dicts :
       self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
-    self.fc1 = nn.Linear(inputSize * self.embSize, 1600)
+    self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
     self.fc2 = nn.Linear(1600, outputSize)
     self.dropout = nn.Dropout(0.3)
 
@@ -32,5 +36,9 @@ class BaseNet(nn.Module):
     if type(m) == nn.Linear:
       torch.nn.init.xavier_uniform_(m.weight)
       m.bias.data.fill_(0.01)
+
+  def extractFeatures(self, dicts, config) :
+    return Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns)
+
 ################################################################################
 
diff --git a/Train.py b/Train.py
index 3309201bab0f9095d1b8d8de1373aa6064b0318e..a0420047598332f4fd14694083ae4af125a56208 100644
--- a/Train.py
+++ b/Train.py
@@ -34,7 +34,7 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, boots
 ################################################################################
 
 ################################################################################
-def extractExamples(debug, ts, strat, config, dicts, network=None) :
+def extractExamples(debug, ts, strat, config, dicts, network, dynamic) :
   examples = []
   with torch.no_grad() :
     EOS = Transition("EOS")
@@ -47,12 +47,12 @@ def extractExamples(debug, ts, strat, config, dicts, network=None) :
         break
       best = min([cand[0] for cand in candidates])
       candidateOracle = random.sample([cand for cand in candidates if cand[0] == best], 1)[0][1]
-      features = Features.extractFeatures(dicts, config)
+      features = network.extractFeatures(dicts, config)
       candidate = candidateOracle.name
       if debug :
         config.printForDebug(sys.stderr)
         print(str([[c[0],c[1].name] for c in candidates])+"\n"+("-"*80)+"\n", file=sys.stderr)
-      if network is not None :
+      if dynamic :
         output = network(features.unsqueeze(0).to(getDevice()))
         scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
         candidate = [[cand[0],cand[2]] for cand in scores if cand[1]][0][1]
@@ -95,17 +95,17 @@ def evalModelAndSave(debug, model, dicts, modelDir, devFile, bestLoss, totalLoss
 ################################################################################
 def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, transitionSet, strategy, sentencesOriginal, bootstrapInterval, silent=False) :
   dicts = Dicts()
-  dicts.readConllu(filename, ["FORM", "UPOS"])
+  dicts.readConllu(filename, ["UPOS"])
   dicts.save(modelDir+"/dicts.json")
+  network = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
   examples = []
   sentences = copy.deepcopy(sentencesOriginal)
   print("%s : Starting to extract examples..."%(timeStamp()), file=sys.stderr)
   for config in sentences :
-    examples += extractExamples(debug, transitionSet, strategy, config, dicts)
+    examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, False)
   print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
   examples = torch.stack(examples)
 
-  network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet)).to(getDevice())
   print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
   optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
   lossFct = torch.nn.CrossEntropyLoss()
@@ -117,7 +117,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
       sentences = copy.deepcopy(sentencesOriginal)
       print("%s : Starting to extract dynamic examples..."%(timeStamp()), file=sys.stderr)
       for config in sentences :
-        examples += extractExamples(debug, transitionSet, strategy, config, dicts, network)
+        examples += extractExamples(debug, transitionSet, strategy, config, dicts, network, True)
       print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
       examples = torch.stack(examples)
 
@@ -154,9 +154,13 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
   dicts.readConllu(filename, ["FORM", "UPOS"])
   dicts.save(modelDir + "/dicts.json")
 
-  policy_net = None
-  target_net = None
-  optimizer = None
+  policy_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
+  target_net = Networks.BaseNet(dicts, len(transitionSet)).to(getDevice())
+  target_net.load_state_dict(policy_net.state_dict())
+  target_net.eval()
+  policy_net.train()
+  optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
+  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
 
   bestLoss = None
   bestScore = None
@@ -178,16 +182,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
         print("Curent epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr)
       sentence = sentences[sentIndex]
       sentence.moveWordIndex(0)
-      state = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
-
-      if policy_net is None :
-        policy_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet)).to(getDevice())
-        target_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet)).to(getDevice())
-        target_net.load_state_dict(policy_net.state_dict())
-        target_net.eval()
-        policy_net.train()
-        optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
-        print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
+      state = policy_net.extractFeatures(dicts, sentence).to(getDevice())
 
       while True :
         missingLinks = getMissingLinks(sentence)
@@ -209,7 +204,7 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
         newState = None
         if appliable :
           applyTransition(transitionSet, strategy, sentence, action.name)
-          newState = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
+          newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
 
         if memory is None :
           memory = ReplayMemory(5000, state.numel())
diff --git a/main.py b/main.py
index ef8a1ea5b732954056dd9df9acdd309cd3d52d95..c9f6518c9cb82ba2f76bea15dac677665bc49a09 100755
--- a/main.py
+++ b/main.py
@@ -44,6 +44,9 @@ if __name__ == "__main__" :
   random.seed(args.seed)
   torch.manual_seed(args.seed)
 
+  if args.bootstrap is not None :
+    args.bootstrap = int(args.bootstrap)
+
   if args.mode == "train" :
     Train.trainMode(args.debug, args.corpus, args.type, args.model, int(args.iter), int(args.batchSize), args.dev, args.bootstrap, args.silent)
   elif args.mode == "decode" :