diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index aeb6ac87cc665db30c5cbb3ec4bc17fa99989214..6925568e06b8829cdf130f7a44886561ca7f10e7 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -73,10 +73,12 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
   std::vector<int> bufferContext, stackContext;
   std::vector<std::string> columns, focusedColumns, treeEmbeddingColumns;
   std::vector<int> focusedBuffer, focusedStack;
+  std::vector<int> treeEmbeddingBuffer, treeEmbeddingStack;
   std::vector<int> maxNbElements;
+  std::vector<int> treeEmbeddingNbElems;
   std::vector<std::pair<int, float>> mlp;
   int rawInputLeftWindow, rawInputRightWindow;
-  int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers;
+  int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, treeEmbeddingSize;
   bool bilstm;
   float lstmDropout;
 
@@ -229,6 +231,37 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
         }))
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding columns :) {column1 column2...}"));
 
-  this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns));
+  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingBuffer](auto sm)
+        {
+          for (auto & index : util::split(sm.str(1), ' '))
+            treeEmbeddingBuffer.emplace_back(std::stoi(index));
+          curIndex++;
+        }))
+    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding buffer :) {index1 index2...}"));
+
+  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingStack](auto sm)
+        {
+          for (auto & index : util::split(sm.str(1), ' '))
+            treeEmbeddingStack.emplace_back(std::stoi(index));
+          curIndex++;
+        }))
+    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding stack :) {index1 index2...}"));
+
+  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding nb :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingNbElems](auto sm)
+        {
+          for (auto & index : util::split(sm.str(1), ' '))
+            treeEmbeddingNbElems.emplace_back(std::stoi(index));
+          curIndex++;
+        }))
+    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding nb :) {size1 size2...}"));
+
+  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&treeEmbeddingSize](auto sm)
+        {
+          treeEmbeddingSize = std::stoi(sm.str(1));
+          curIndex++;
+        }))
+    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding size :) value"));
+
+  this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize));
 }
 
diff --git a/torch_modules/include/DepthLayerTreeEmbedding.hpp b/torch_modules/include/DepthLayerTreeEmbedding.hpp
index 6eb069b99d308982327142c15acc85c6cb50c042..436a082a06121a2c62f50da0b5f5ef4b79b99ba8 100644
--- a/torch_modules/include/DepthLayerTreeEmbedding.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbedding.hpp
@@ -9,17 +9,15 @@ class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule
 {
   private :
 
-  std::vector<std::string> columns{"DEPREL"};
-  std::vector<int> focusedBuffer{0};
-  std::vector<int> focusedStack{0};
-  std::string firstElem{"__special_DepthLayerTreeEmbeddingImpl__"};
+  std::vector<int> maxElemPerDepth;
+  std::vector<std::string> columns;
+  std::vector<int> focusedBuffer;
+  std::vector<int> focusedStack;
   std::vector<LSTM> depthLstm;
-  int maxDepth;
-  int maxElemPerDepth;
 
   public :
 
-  DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options);
+  DepthLayerTreeEmbeddingImpl(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, LSTMImpl::LSTMOptions options);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp
index 9cb051d45f15252c304d84174db1965f357eb721..f0b58dc099d512e73ce95c9d9c808cf206006cc5 100644
--- a/torch_modules/include/LSTMNetwork.hpp
+++ b/torch_modules/include/LSTMNetwork.hpp
@@ -29,7 +29,7 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
 
   public :
 
-  LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns);
+  LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize);
   torch::Tensor forward(torch::Tensor input) override;
   std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
 };
diff --git a/torch_modules/src/DepthLayerTreeEmbedding.cpp b/torch_modules/src/DepthLayerTreeEmbedding.cpp
index 3f1926d72dd2f8dad1109ca15b2eed32e957ff43..aa4c7aef388105493ce3315d7761acb96b7693ba 100644
--- a/torch_modules/src/DepthLayerTreeEmbedding.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbedding.cpp
@@ -1,9 +1,10 @@
 #include "DepthLayerTreeEmbedding.hpp"
 
-DepthLayerTreeEmbeddingImpl::DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : maxDepth(maxDepth), maxElemPerDepth(maxElemPerDepth)
+DepthLayerTreeEmbeddingImpl::DepthLayerTreeEmbeddingImpl(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, LSTMImpl::LSTMOptions options) :
+  maxElemPerDepth(maxElemPerDepth), columns(columns), focusedBuffer(focusedBuffer), focusedStack(focusedStack)
 {
-  for (int i = 0; i < maxDepth; i++)
-    depthLstm.emplace_back(register_module(fmt::format("lstm_{}",i), LSTM(embeddingsSize, outEmbeddingsSize, options)));
+  for (unsigned int i = 0; i < maxElemPerDepth.size(); i++)
+    depthLstm.emplace_back(register_module(fmt::format("lstm_{}",i), LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options)));
 }
 
 torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input)
@@ -12,9 +13,13 @@ torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input)
 
   std::vector<torch::Tensor> outputs;
 
-  for (unsigned int i = 0; i < depthLstm.size(); i++)
-    for (unsigned int j = 0; j < focusedBuffer.size()+focusedStack.size(); j++)
-      outputs.emplace_back(depthLstm[i](input.narrow(1,i*(focusedBuffer.size()+focusedStack.size())*columns.size()*maxElemPerDepth + j*maxElemPerDepth, maxElemPerDepth)));
+  int offset = 0;
+  for (unsigned int focused = 0; focused < focusedBuffer.size()+focusedStack.size(); focused++)
+    for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
+    {
+      outputs.emplace_back(depthLstm[depth](context.narrow(1, offset, maxElemPerDepth[depth]*columns.size()).view({input.size(0), maxElemPerDepth[depth], (long)columns.size()*input.size(2)})));
+      offset += maxElemPerDepth[depth]*columns.size();
+    }
 
   return torch::cat(outputs, 1);
 }
@@ -23,15 +28,18 @@ std::size_t DepthLayerTreeEmbeddingImpl::getOutputSize()
 {
   std::size_t outputSize = 0;
 
-  for (auto & lstm : depthLstm)
-    outputSize += lstm->getOutputSize(maxElemPerDepth);
+  for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
+    outputSize += depthLstm[depth]->getOutputSize(maxElemPerDepth[depth]);
 
-  return outputSize;
+  return outputSize*(focusedBuffer.size()+focusedStack.size());
 }
 
 std::size_t DepthLayerTreeEmbeddingImpl::getInputSize()
 {
-  return (focusedBuffer.size()+focusedStack.size())*columns.size()*maxDepth*maxElemPerDepth;
+  int inputSize = 0;
+  for (int maxElem : maxElemPerDepth)
+    inputSize += (focusedBuffer.size()+focusedStack.size())*maxElem*columns.size();
+  return inputSize;
 }
 
 void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
@@ -48,11 +56,27 @@ void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> &
       focusedIndexes.emplace_back(-1);
 
   for (auto & contextElement : context)
-  {
     for (auto index : focusedIndexes)
     {
+      std::vector<std::string> childs{std::to_string(index)};
 
+      for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
+      {
+        std::vector<std::string> newChilds;
+        for (auto & child : childs)
+          if (config.has(Config::childsColName, std::stoi(child), 0))
+          {
+            auto val = util::split(config.getAsFeature(Config::childsColName, std::stoi(child)).get(), '|');
+            newChilds.insert(newChilds.end(), val.begin(), val.end());
+          }
+        childs = newChilds;
+        for (int i = 0; i < maxElemPerDepth[depth]; i++)
+          for (auto & col : columns)
+            if (i < (int)childs.size() and config.has(col, std::stoi(childs[i]), 0))
+              contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(childs[i]))));
+            else
+              contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+      }
     }
-  }
 }
 
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
index 2a9b1b4dca054e89a22322cdf7301998e720ee37..7a25024b67d90bd2f5f2acde0839baa308e4bc74 100644
--- a/torch_modules/src/LSTMNetwork.cpp
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -1,6 +1,6 @@
 #include "LSTMNetwork.hpp"
 
-LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns)
+LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize)
 {
   LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false};
   auto lstmOptionsAll = lstmOptions;
@@ -26,7 +26,7 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
   if (!treeEmbeddingColumns.empty())
   {
     hasTreeEmbedding = true;
-    treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(1,3,embeddingsSize,128,treeEmbeddingColumns,focusedBufferIndexes,focusedStackIndexes,lstmOptionsAll));
+    treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,lstmOptionsAll));
     treeEmbedding->setFirstInputIndex(currentInputSize);
     currentOutputSize += treeEmbedding->getOutputSize();
     currentInputSize += treeEmbedding->getInputSize();
@@ -89,13 +89,15 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
   context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
 
   contextLSTM->addToContext(context, dict, config);
+
   if (hasRawInputLSTM)
     rawInputLSTM->addToContext(context, dict, config);
-  fmt::print(stderr, "before {}\n", context.back().size());
+
   if (hasTreeEmbedding)
     treeEmbedding->addToContext(context, dict, config);
-  fmt::print(stderr, "after {}\n", context.back().size());
+
   splitTransLSTM->addToContext(context, dict, config);
+
   for (auto & lstm : focusedLstms)
     lstm->addToContext(context, dict, config);