diff --git a/torch_modules/src/RTLSTMNetwork.cpp b/torch_modules/src/RTLSTMNetwork.cpp
index b59892c7770404e64b63ea653b600088b4aa1f33..75ded3a96f15d532566716226d056969441d6287 100644
--- a/torch_modules/src/RTLSTMNetwork.cpp
+++ b/torch_modules/src/RTLSTMNetwork.cpp
@@ -23,49 +23,59 @@ RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBor
 
 torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input)
 {
-  input = input.squeeze();
-  if (input.dim() != 1)
-    util::myThrow(fmt::format("Does not support batched input (dim()={})", input.dim()));
-
-  auto focusedIndexes = input.narrow(0, 0, focusedBufferIndexes.size()+focusedStackIndexes.size());
-  auto computeOrder = input.narrow(0, focusedIndexes.size(0), leftBorder+rightBorder+1);
-  auto childsFlat = input.narrow(0, focusedIndexes.size(0)+computeOrder.size(0), maxNbChilds*(leftBorder+rightBorder+1));
-  auto childs = torch::reshape(childsFlat, {computeOrder.size(0), maxNbChilds});
-  auto wordIndexes = input.narrow(0, focusedIndexes.size(0)+computeOrder.size(0)+childsFlat.size(0), columns.size()*(leftBorder+rightBorder+1));
+  if (input.dim() == 1)
+    input = input.unsqueeze(0);
+
+  auto focusedIndexes = input.narrow(1, 0, focusedBufferIndexes.size()+focusedStackIndexes.size());
+  auto computeOrder = input.narrow(1, focusedIndexes.size(1), leftBorder+rightBorder+1);
+  auto childsFlat = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1), maxNbChilds*(leftBorder+rightBorder+1));
+  auto childs = torch::reshape(childsFlat, {childsFlat.size(0), computeOrder.size(1), maxNbChilds});
+  auto wordIndexes = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1)+childsFlat.size(1), columns.size()*(leftBorder+rightBorder+1));
   auto baseEmbeddings = wordEmbeddings(wordIndexes);
-  auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {(int)baseEmbeddings.size(0)/(int)columns.size(), (int)baseEmbeddings.size(1)*(int)columns.size()}).unsqueeze(0);
-  auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output.squeeze();
-  std::vector<torch::Tensor> treeRepresentations(vectorRepresentations.size(0), nullTree);
-  for (unsigned int i = 0; i < computeOrder.size(0); i++)
+  auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {baseEmbeddings.size(0), (int)baseEmbeddings.size(1)/(int)columns.size(), (int)baseEmbeddings.size(2)*(int)columns.size()});
+  auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output;
+
+  std::vector<std::map<int, torch::Tensor>> treeRepresentations;
+  for (unsigned int batch = 0; batch < computeOrder.size(0); batch++)
   {
-    int index = computeOrder[i].item<int>();
-    if (index == -1)
-      break;
-    std::vector<torch::Tensor> inputVector;
-    inputVector.emplace_back(torch::cat({vectorRepresentations[index], S}, 0));
-    for (unsigned int childIndex = 0; childIndex < maxNbChilds; childIndex++)
+    treeRepresentations.emplace_back();
+    for (unsigned int i = 0; i < computeOrder[batch].size(0); i++)
     {
-      int child = childs[index][childIndex].item<int>();
-      if (child == -1)
+      int index = computeOrder[batch][i].item<int>();
+      if (index == -1)
         break;
-      inputVector.emplace_back(torch::cat({vectorRepresentations[index], treeRepresentations[child]}, 0));
+      std::vector<torch::Tensor> inputVector;
+      inputVector.emplace_back(torch::cat({vectorRepresentations[batch][index], S}, 0));
+      for (unsigned int childIndex = 0; childIndex < maxNbChilds; childIndex++)
+      {
+        int child = childs[batch][index][childIndex].item<int>();
+        if (child == -1)
+          break;
+        inputVector.emplace_back(torch::cat({vectorRepresentations[batch][index], treeRepresentations[batch].count(child) ? treeRepresentations[batch][child] : nullTree}, 0));
+      }
+      auto lstmInput = torch::stack(inputVector, 0).unsqueeze(0);
+      auto lstmOut = treeLSTM(lstmInput).output.permute({1,0,2})[-1].squeeze();
+      treeRepresentations[batch][index] = lstmOut;
     }
-    auto lstmInput = torch::stack(inputVector, 0).unsqueeze(0);
-    auto lstmOut = treeLSTM(lstmInput).output.permute({1,0,2})[-1].squeeze();
-    treeRepresentations[index] = lstmOut;
   }
 
   std::vector<torch::Tensor> focusedTrees;
-  for (unsigned int i = 0; i < focusedIndexes.size(0); i++)
+  std::vector<torch::Tensor> representations;
+  for (unsigned int batch = 0; batch < focusedIndexes.size(0); batch++)
   {
-    int index = focusedIndexes[i].item<int>();
-    if (index == -1)
-      focusedTrees.emplace_back(nullTree);
-    else
-      focusedTrees.emplace_back(treeRepresentations[index]);
+    focusedTrees.clear();
+    for (unsigned int i = 0; i < focusedIndexes[batch].size(0); i++)
+    {
+      int index = focusedIndexes[batch][i].item<int>();
+      if (index == -1)
+        focusedTrees.emplace_back(nullTree);
+      else
+        focusedTrees.emplace_back(treeRepresentations[batch].count(index) ? treeRepresentations[batch][index] : nullTree);
+    }
+    representations.emplace_back(torch::cat(focusedTrees, 0).unsqueeze(0));
   }
 
-  auto representation = torch::cat(focusedTrees, 0);
+  auto representation = torch::cat(representations, 0);
   return linear2(torch::relu(linear1(representation)));
 }