diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 0fc17c65fa325fe2263640f362dbf821567719c7..fd8d1bca540f96c15c9a7ead5306ea9e4e87ac50 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -114,7 +114,7 @@ void Config::print(FILE * dest) const
 
 void Config::printForDebug(FILE * dest) const
 {
-  static constexpr int windowSize = 5;
+  static constexpr int windowSize = 10;
   static constexpr int lettersWindowSize = 40;
   static constexpr int maxWordLength = 7;
 
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index 5c8493e562ed9ea2bf50df7bb6fd5ce2983a391d..cdc55145bd7f20cb593e9e9fc82afd02a0564b9c 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -7,7 +7,7 @@
 
 class NeuralNetworkImpl : public torch::nn::Module
 {
-  private : 
+  protected : 
 
   int leftBorder{5};
   int rightBorder{5};
@@ -23,7 +23,7 @@ class NeuralNetworkImpl : public torch::nn::Module
   public :
 
   virtual torch::Tensor forward(torch::Tensor input) = 0;
-  std::vector<long> extractContext(Config & config, Dict & dict) const;
+  virtual std::vector<long> extractContext(Config & config, Dict & dict) const;
   int getContextSize() const;
   void setColumns(const std::vector<std::string> & columns);
 };
diff --git a/torch_modules/include/RTLSTMNetwork.hpp b/torch_modules/include/RTLSTMNetwork.hpp
index d30a6e62efe2f3fd76b58bc4c559a458292eeeb0..5d7692523f7661759314b1fb7f1c7a8d7dcbd0b0 100644
--- a/torch_modules/include/RTLSTMNetwork.hpp
+++ b/torch_modules/include/RTLSTMNetwork.hpp
@@ -7,16 +7,23 @@ class RTLSTMNetworkImpl : public NeuralNetworkImpl
 {
   private :
 
+  static constexpr long maxNbChilds{8};
+  static inline std::vector<long> focusedBufferIndexes{0,1,2};
+  static inline std::vector<long> focusedStackIndexes{0,1};
+
   torch::nn::Embedding wordEmbeddings{nullptr};
   torch::nn::Linear linear1{nullptr};
   torch::nn::Linear linear2{nullptr};
-  torch::nn::Dropout dropout{nullptr};
-  torch::nn::LSTM lstm{nullptr};
+  torch::nn::LSTM vectorBiLSTM{nullptr};
+  torch::nn::LSTM treeLSTM{nullptr};
+  torch::Tensor S;
+  torch::Tensor nullTree;
 
   public :
 
   RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
   torch::Tensor forward(torch::Tensor input) override;
+  std::vector<long> extractContext(Config & config, Dict & dict) const override;
 };
 
 #endif
diff --git a/torch_modules/src/RTLSTMNetwork.cpp b/torch_modules/src/RTLSTMNetwork.cpp
index 6cc8f70bab61d2924ed6c3c022e63303e5319fc0..b59892c7770404e64b63ea653b600088b4aa1f33 100644
--- a/torch_modules/src/RTLSTMNetwork.cpp
+++ b/torch_modules/src/RTLSTMNetwork.cpp
@@ -3,31 +3,176 @@
 RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
 {
   constexpr int embeddingsSize = 30;
-  constexpr int lstmOutputSize = 500;
+  constexpr int lstmOutputSize = 128;
+  constexpr int treeEmbeddingsSize = 256;
   constexpr int hiddenSize = 500;
+
   setLeftBorder(leftBorder);
   setRightBorder(rightBorder);
   setNbStackElements(nbStackElements);
   setColumns({"FORM", "UPOS"});
 
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
-  linear1 = register_module("linear1", torch::nn::Linear(lstmOutputSize, hiddenSize));
+  linear1 = register_module("linear1", torch::nn::Linear(treeEmbeddingsSize*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
   linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
-  dropout = register_module("dropout", torch::nn::Dropout(0.3));
-  lstm = register_module("lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, lstmOutputSize).batch_first(true)));
+  vectorBiLSTM = register_module("vector_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize*columns.size(), lstmOutputSize).batch_first(true).bidirectional(true)));
+  treeLSTM = register_module("tree_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(treeEmbeddingsSize+2*lstmOutputSize, treeEmbeddingsSize).batch_first(true).bidirectional(false)));
+  S = register_parameter("S", torch::randn(treeEmbeddingsSize));
+  nullTree = register_parameter("null_tree", torch::randn(treeEmbeddingsSize));
 }
 
 torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input)
 {
-  // input dim = {batch, sequence, embeddings}
-  auto wordsAsEmb = wordEmbeddings(input);
-  if (wordsAsEmb.dim() == 2)
-    wordsAsEmb = torch::unsqueeze(wordsAsEmb, 0);
-  auto lstmOut = lstm(wordsAsEmb).output;
-  // reshaped dim = {sequence, batch, embeddings}
-  auto reshaped = lstmOut.permute({1,0,2});
-  auto res = linear2(torch::relu(linear1(reshaped[-1])));
-
-  return res;
+  input = input.squeeze();
+  if (input.dim() != 1)
+    util::myThrow(fmt::format("Does not support batched input (dim()={})", input.dim()));
+
+  auto focusedIndexes = input.narrow(0, 0, focusedBufferIndexes.size()+focusedStackIndexes.size());
+  auto computeOrder = input.narrow(0, focusedIndexes.size(0), leftBorder+rightBorder+1);
+  auto childsFlat = input.narrow(0, focusedIndexes.size(0)+computeOrder.size(0), maxNbChilds*(leftBorder+rightBorder+1));
+  auto childs = torch::reshape(childsFlat, {computeOrder.size(0), maxNbChilds});
+  auto wordIndexes = input.narrow(0, focusedIndexes.size(0)+computeOrder.size(0)+childsFlat.size(0), columns.size()*(leftBorder+rightBorder+1));
+  auto baseEmbeddings = wordEmbeddings(wordIndexes);
+  auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {(int)baseEmbeddings.size(0)/(int)columns.size(), (int)baseEmbeddings.size(1)*(int)columns.size()}).unsqueeze(0);
+  auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output.squeeze();
+  std::vector<torch::Tensor> treeRepresentations(vectorRepresentations.size(0), nullTree);
+  for (unsigned int i = 0; i < computeOrder.size(0); i++)
+  {
+    int index = computeOrder[i].item<int>();
+    if (index == -1)
+      break;
+    std::vector<torch::Tensor> inputVector;
+    inputVector.emplace_back(torch::cat({vectorRepresentations[index], S}, 0));
+    for (unsigned int childIndex = 0; childIndex < maxNbChilds; childIndex++)
+    {
+      int child = childs[index][childIndex].item<int>();
+      if (child == -1)
+        break;
+      inputVector.emplace_back(torch::cat({vectorRepresentations[index], treeRepresentations[child]}, 0)); // NOTE(review): should this be vectorRepresentations[child] (child's own biLSTM vector)? — confirm
+    }
+    auto lstmInput = torch::stack(inputVector, 0).unsqueeze(0);
+    auto lstmOut = treeLSTM(lstmInput).output.permute({1,0,2})[-1].squeeze();
+    treeRepresentations[index] = lstmOut;
+  }
+
+  std::vector<torch::Tensor> focusedTrees;
+  for (unsigned int i = 0; i < focusedIndexes.size(0); i++)
+  {
+    int index = focusedIndexes[i].item<int>();
+    if (index == -1)
+      focusedTrees.emplace_back(nullTree);
+    else
+      focusedTrees.emplace_back(treeRepresentations[index]);
+  }
+
+  auto representation = torch::cat(focusedTrees, 0);
+  return linear2(torch::relu(linear1(representation)));
+}
+
+std::vector<long> RTLSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
+{
+  std::vector<long> contextIndexes;
+  std::stack<int> leftContext;
+  for (int index = config.getWordIndex()-1; config.has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
+    if (config.isToken(index))
+      leftContext.push(index);
+
+  while ((int)contextIndexes.size() < leftBorder-(int)leftContext.size())
+    contextIndexes.emplace_back(-1);
+  while (!leftContext.empty())
+  {
+    contextIndexes.emplace_back(leftContext.top());
+    leftContext.pop();
+  }
+
+  for (int index = config.getWordIndex(); config.has(0,index,0) && (int)contextIndexes.size() < leftBorder+rightBorder+1; ++index)
+    if (config.isToken(index))
+      contextIndexes.emplace_back(index);
+
+  while ((int)contextIndexes.size() < leftBorder+rightBorder+1)
+    contextIndexes.emplace_back(-1);
+
+  std::map<long, long> indexInContext;
+  for (auto & l : contextIndexes)
+    indexInContext.emplace(std::make_pair(l, indexInContext.size()));
+
+  std::vector<long> headOf;
+  for (auto & l : contextIndexes)
+  {
+    if (l == -1)
+      headOf.push_back(-1);
+    else
+    {
+      auto & head = config.getAsFeature(Config::headColName, l);
+      if (util::isEmpty(head) or head == "_")
+        headOf.push_back(-1);
+      else if  (indexInContext.count(std::stoi(head))) // NOTE(review): std::stoi throws on non-numeric heads; only "" and "_" are guarded — confirm input guarantees
+        headOf.push_back(std::stoi(head));
+      else
+        headOf.push_back(-1);
+    }
+  }
+
+  std::vector<std::vector<long>> childs(headOf.size());
+  for (unsigned int i = 0; i < headOf.size(); i++)
+    if (headOf[i] != -1)
+      childs[indexInContext[headOf[i]]].push_back(contextIndexes[i]);
+
+  std::vector<long> treeComputationOrder;
+  std::vector<bool> treeIsComputed(contextIndexes.size(), false);
+
+  std::function<void(long)> depthFirst;
+  depthFirst = [&config, &depthFirst, &indexInContext, &treeComputationOrder, &treeIsComputed, &childs](long root)
+  {
+    if (!indexInContext.count(root))
+      return;
+
+    if (treeIsComputed[indexInContext[root]])
+      return;
+
+    for (auto child : childs[indexInContext[root]])
+      depthFirst(child);
+
+    treeIsComputed[indexInContext[root]] = true;
+    treeComputationOrder.push_back(indexInContext[root]);
+  };
+
+  for (auto & l : focusedBufferIndexes)
+    if (contextIndexes[leftBorder+l] != -1)
+      depthFirst(contextIndexes[leftBorder+l]);
+
+  for (auto & l : focusedStackIndexes)
+    if (config.hasStack(l))
+      depthFirst(config.getStack(l));
+
+  std::vector<long> context;
+  
+  for (auto & c : focusedBufferIndexes)
+    context.push_back(leftBorder+c);
+  for (auto & c : focusedStackIndexes)
+    if (config.hasStack(c) && indexInContext.count(config.getStack(c)))
+      context.push_back(indexInContext[config.getStack(c)]);
+    else
+      context.push_back(-1);
+  for (auto & c : treeComputationOrder)
+    context.push_back(c);
+  while (context.size() < contextIndexes.size()+focusedBufferIndexes.size()+focusedStackIndexes.size())
+    context.push_back(-1);
+  for (auto & c : childs)
+  {
+    for (unsigned int i = 0; i < maxNbChilds; i++)
+      if (i < c.size())
+        context.push_back(indexInContext[c[i]]);
+      else
+        context.push_back(-1);
+  }
+  for (auto & l : contextIndexes)
+    for (auto & col : columns)
+      if (l == -1)
+        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+      else
+        context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, l)));
+
+  return context;
 }
 
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 6e889171c5f1fa461f333605446fb8545d287270..5a8c30230b0f56b1aeb98b04879d40cf3d51ab20 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -18,7 +18,7 @@ class Trainer
   DataLoader dataLoader{nullptr};
   std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
-  int batchSize{50};
+  int batchSize{1};
   int nbExamples{0};
 
   public :
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 590504ef6fc8c75d31a6e597188c9648e099ee96..a74078cd8f6e7a567687e248b9f924b3e4b6535a 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -63,7 +63,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
 
 float Trainer::epoch(bool printAdvancement)
 {
-  constexpr int printInterval = 2000;
+  constexpr int printInterval = 50;
   int nbExamplesProcessed = 0;
   float totalLoss = 0.0;
   float lossSoFar = 0.0;
@@ -81,6 +81,8 @@ float Trainer::epoch(bool printAdvancement)
     auto labels = batch.target.squeeze();
 
     auto prediction = machine.getClassifier()->getNN()(data);
+    if (prediction.dim() == 1)
+      prediction = prediction.unsqueeze(0);
 
     labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));