diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 54dccff6694b0bb2f541302318a93c8d7bcaa1c3..20a505639a96547459148c2752af5602b0b06e08 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -1,8 +1,5 @@
 #include "Classifier.hpp"
 #include "util.hpp"
-#include "ConcatWordsNetwork.hpp"
-#include "RLTNetwork.hpp"
-#include "CNNNetwork.hpp"
 #include "LSTMNetwork.hpp"
 #include "RandomNetwork.hpp"
 
@@ -40,45 +37,6 @@ void Classifier::initNeuralNetwork(const std::string & topology)
         this->nn.reset(new RandomNetworkImpl(this->transitionSet->size()));
       }
     },
-    {
-      std::regex("ConcatWords\\(\\{(.*)\\},\\{(.*)\\}\\)"),
-      "ConcatWords({bufferContext},{stackContext}) : Concatenate embeddings of words in context.",
-      [this,topology](auto sm)
-      {
-        std::vector<int> bufferContext, stackContext;
-        for (auto s : util::split(sm.str(1), ','))
-          bufferContext.emplace_back(std::stoi(s));
-        for (auto s : util::split(sm.str(2), ','))
-          stackContext.emplace_back(std::stoi(s));
-        this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), bufferContext, stackContext));
-      }
-    },
-    {
-      std::regex("CNN\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
-      "CNN(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
-      [this,topology](auto sm)
-      {
-        std::vector<int> focusedBuffer, focusedStack, maxNbElements, bufferContext, stackContext;
-        std::vector<std::string> focusedColumns, columns;
-        for (auto s : util::split(sm.str(2), ','))
-          bufferContext.emplace_back(std::stoi(s));
-        for (auto s : util::split(sm.str(3), ','))
-          stackContext.emplace_back(std::stoi(s));
-        for (auto s : util::split(sm.str(4), ','))
-          columns.emplace_back(s);
-        for (auto s : util::split(sm.str(5), ','))
-          focusedBuffer.push_back(std::stoi(s));
-        for (auto s : util::split(sm.str(6), ','))
-          focusedStack.push_back(std::stoi(s));
-        for (auto s : util::split(sm.str(7), ','))
-          focusedColumns.emplace_back(s);
-        for (auto s : util::split(sm.str(8), ','))
-          maxNbElements.push_back(std::stoi(s));
-        if (focusedColumns.size() != maxNbElements.size())
-          util::myThrow("focusedColumns.size() != maxNbElements.size()");
-        this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm.str(9)), std::stoi(sm.str(10))));
-      }
-    },
     {
       std::regex("LSTM\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
       "LSTM(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
@@ -105,14 +63,6 @@ void Classifier::initNeuralNetwork(const std::string & topology)
         this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm.str(9)), std::stoi(sm.str(10))));
       }
     },
-    {
-      std::regex("RLT\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
-      "RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
-      [this,topology](auto sm)
-      {
-        this->nn.reset(new RLTNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), std::stoi(sm.str(2)), std::stoi(sm.str(3))));
-      }
-    },
   };
 
   std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);
diff --git a/torch_modules/include/CNN.hpp b/torch_modules/include/CNN.hpp
index 509be856b83e40431aef0c138747a70b62847201..66c405ca13bdb6ddf31ddfe34e2da42788fddc79 100644
--- a/torch_modules/include/CNN.hpp
+++ b/torch_modules/include/CNN.hpp
@@ -17,7 +17,7 @@ class CNNImpl : public torch::nn::Module
 
   CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize);
   torch::Tensor forward(torch::Tensor input);
-  int getOutputSize();
+  std::size_t getOutputSize(); // number of output features (windowSizes.size()*nbFilters, see CNN.cpp); widened from int
 
 };
 TORCH_MODULE(CNN);
diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp
deleted file mode 100644
index 6fb985ac482f88856cc8b2fa6bc07f4971e79546..0000000000000000000000000000000000000000
--- a/torch_modules/include/CNNNetwork.hpp
+++ /dev/null
@@ -1,35 +0,0 @@
-#ifndef CNNNETWORK__H
-#define CNNNETWORK__H
-
-#include "NeuralNetwork.hpp"
-#include "CNN.hpp"
-
-class CNNNetworkImpl : public NeuralNetworkImpl
-{
-  private :
-
-  int unknownValueThreshold;
-  std::vector<std::string> focusedColumns;
-  std::vector<int> maxNbElements;
-  int leftWindowRawInput;
-  int rightWindowRawInput;
-  int rawInputSize;
-
-  torch::nn::Embedding wordEmbeddings{nullptr};
-  torch::nn::Dropout embeddingsDropout{nullptr};
-  torch::nn::Dropout cnnDropout{nullptr};
-  torch::nn::Dropout hiddenDropout{nullptr};
-  torch::nn::Linear linear1{nullptr};
-  torch::nn::Linear linear2{nullptr};
-  CNN contextCNN{nullptr};
-  CNN rawInputCNN{nullptr};
-  std::vector<CNN> cnns;
-
-  public :
-
-  CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);
-  torch::Tensor forward(torch::Tensor input) override;
-  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
-};
-
-#endif
diff --git a/torch_modules/include/ConcatWordsNetwork.hpp b/torch_modules/include/ConcatWordsNetwork.hpp
deleted file mode 100644
index b2c9691dd3a062cf600e99ce1f7ec9ca4d478643..0000000000000000000000000000000000000000
--- a/torch_modules/include/ConcatWordsNetwork.hpp
+++ /dev/null
@@ -1,21 +0,0 @@
-#ifndef CONCATWORDSNETWORK__H
-#define CONCATWORDSNETWORK__H
-
-#include "NeuralNetwork.hpp"
-
-class ConcatWordsNetworkImpl : public NeuralNetworkImpl
-{
-  private :
-
-  torch::nn::Embedding wordEmbeddings{nullptr};
-  torch::nn::Linear linear1{nullptr};
-  torch::nn::Linear linear2{nullptr};
-  torch::nn::Dropout dropout{nullptr};
-
-  public :
-
-  ConcatWordsNetworkImpl(int nbOutputs, const std::vector<int> & bufferContext, const std::vector<int> & stackContext);
-  torch::Tensor forward(torch::Tensor input) override;
-};
-
-#endif
diff --git a/torch_modules/include/ContextLSTM.hpp b/torch_modules/include/ContextLSTM.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..136029cd33d2c3a1825ce23d8902c860242eecf8
--- /dev/null
+++ b/torch_modules/include/ContextLSTM.hpp
@@ -0,0 +1,30 @@
+#ifndef CONTEXTLSTM__H
+#define CONTEXTLSTM__H
+
+#include <torch/torch.h>
+#include "Submodule.hpp"
+#include "LSTM.hpp"
+
+class ContextLSTMImpl : public torch::nn::Module, public Submodule // LSTM over per-word column features of the buffer/stack context
+{
+  private :
+
+  LSTM lstm{nullptr}; // consumes columns.size()*embeddingsSize features per word (see ContextLSTM.cpp)
+  std::vector<std::string> columns; // column names whose features are extracted for every context word
+  std::vector<int> bufferContext; // relative buffer positions that make up the context
+  std::vector<int> stackContext; // stack indices included in the context (missing ones map to null values)
+  int unknownValueThreshold; // dict-occurrence count at or below which a training-time unknown-value variant is forked
+  std::vector<std::string> unknownValueColumns{"FORM", "LEMMA"}; // only these columns get the unknown-value treatment
+
+  public :
+
+  ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold);
+  torch::Tensor forward(torch::Tensor input); // reads this submodule's slice starting at firstInputIndex
+  std::size_t getOutputSize() override;
+  std::size_t getInputSize() override; // columns.size() * (bufferContext.size()+stackContext.size())
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; // appends dict indexes for every (word, column) pair
+};
+TORCH_MODULE(ContextLSTM);
+
+#endif
+
diff --git a/torch_modules/include/DepthLayerTreeEmbedding.hpp b/torch_modules/include/DepthLayerTreeEmbedding.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..d471e6b4d57933c56b4e4f2f7261963414c00ff2
--- /dev/null
+++ b/torch_modules/include/DepthLayerTreeEmbedding.hpp
@@ -0,0 +1,25 @@
+#ifndef DEPTHLAYERTREEEMBEDDING__H
+#define DEPTHLAYERTREEEMBEDDING__H
+
+#include <torch/torch.h>
+#include "fmt/core.h"
+#include "LSTM.hpp"
+
+class DepthLayerTreeEmbeddingImpl : public torch::nn::Module // WIP: forward()/getOutputSize() are empty stubs in DepthLayerTreeEmbedding.cpp
+{
+  private :
+
+  std::vector<LSTM> depthLstm; // presumably one LSTM per tree depth — never populated yet, confirm intent
+  int maxDepth;
+  int maxElemPerDepth;
+
+  public :
+
+  DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth);
+  torch::Tensor forward(torch::Tensor input);
+  int getOutputSize(); // NOTE(review): sibling modules use std::size_t here; consider aligning
+};
+TORCH_MODULE(DepthLayerTreeEmbedding);
+
+#endif
+
diff --git a/torch_modules/include/FocusedColumnLSTM.hpp b/torch_modules/include/FocusedColumnLSTM.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..6ea836a041017fb1fdf6725506b6eb1f561bdb99
--- /dev/null
+++ b/torch_modules/include/FocusedColumnLSTM.hpp
@@ -0,0 +1,28 @@
+#ifndef FOCUSEDCOLUMNLSTM__H
+#define FOCUSEDCOLUMNLSTM__H
+
+#include <torch/torch.h>
+#include "Submodule.hpp"
+#include "LSTM.hpp"
+
+class FocusedColumnLSTMImpl : public torch::nn::Module, public Submodule // one shared LSTM over the sub-elements of a single column, for each focused word
+{
+  private :
+
+  LSTM lstm{nullptr}; // shared across all focused words (see FocusedColumnLSTM.cpp)
+  std::vector<int> focusedBuffer, focusedStack; // buffer positions / stack indices whose column value is embedded
+  std::string column; // the single column this submodule looks at
+  int maxNbElements; // fixed number of sub-elements fed to the LSTM per focused word
+
+  public :
+
+  FocusedColumnLSTMImpl(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options);
+  torch::Tensor forward(torch::Tensor input); // concatenates one LSTM output per focused word
+  std::size_t getOutputSize() override;
+  std::size_t getInputSize() override; // (focusedBuffer.size()+focusedStack.size()) * maxNbElements
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+};
+TORCH_MODULE(FocusedColumnLSTM);
+
+#endif
+
diff --git a/torch_modules/include/LSTM.hpp b/torch_modules/include/LSTM.hpp
index eb06c4508aeb733027b7050aadd0ea1933095696..c45cb9fff41aa6f9fc0883bcf7e37c9db5e3a8e9 100644
--- a/torch_modules/include/LSTM.hpp
+++ b/torch_modules/include/LSTM.hpp
@@ -6,6 +6,10 @@
 
 class LSTMImpl : public torch::nn::Module
 {
+  public :
+
+  using LSTMOptions = std::tuple<bool,bool,int,float,bool>; // NOTE(review): field meanings are fixed by LSTM.cpp — confirm the (bool,bool,int,float,bool) order before constructing one
+
   private :
 
   torch::nn::LSTM lstm{nullptr};
@@ -13,7 +17,7 @@ class LSTMImpl : public torch::nn::Module
 
   public :
 
-  LSTMImpl(int inputSize, int outputSize, std::tuple<bool,bool,int,float,bool> options);
+  LSTMImpl(int inputSize, int outputSize, LSTMOptions options); // LSTMOptions alias declared in the public section above
   torch::Tensor forward(torch::Tensor input);
   int getOutputSize(int sequenceLength);
 };
diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp
index 860f40701d8d1d9e572b12c733996ac5f25c3b71..5762ad1ad5c32211bf190d05d6c259a08d793cb6 100644
--- a/torch_modules/include/LSTMNetwork.hpp
+++ b/torch_modules/include/LSTMNetwork.hpp
@@ -2,29 +2,28 @@
 #define LSTMNETWORK__H
 
 #include "NeuralNetwork.hpp"
-#include "LSTM.hpp"
+#include "ContextLSTM.hpp"
+#include "RawInputLSTM.hpp"
+#include "SplitTransLSTM.hpp"
+#include "FocusedColumnLSTM.hpp"
 
 class LSTMNetworkImpl : public NeuralNetworkImpl
 {
   private :
 
-  int unknownValueThreshold;
-  std::vector<std::string> focusedColumns;
-  std::vector<int> maxNbElements;
-  int leftWindowRawInput;
-  int rightWindowRawInput;
-  int rawInputSize;
-
   torch::nn::Embedding wordEmbeddings{nullptr};
   torch::nn::Dropout embeddingsDropout{nullptr};
   torch::nn::Dropout lstmDropout{nullptr};
   torch::nn::Dropout hiddenDropout{nullptr};
   torch::nn::Linear linear1{nullptr};
   torch::nn::Linear linear2{nullptr};
-  LSTM contextLSTM{nullptr};
-  LSTM rawInputLSTM{nullptr};
-  LSTM splitTransLSTM{nullptr};
-  std::vector<LSTM> lstms;
+
+  ContextLSTM contextLSTM{nullptr}; // word-context features (buffer/stack windows)
+  RawInputLSTM rawInputLSTM{nullptr}; // raw character window; only meaningful when hasRawInputLSTM is true
+  SplitTransLSTM splitTransLSTM{nullptr}; // appliable split-transition features
+  std::vector<FocusedColumnLSTM> focusedLstms; // one submodule per focused column
+
+  bool hasRawInputLSTM{false}; // presumably set when the raw-input window is enabled — confirm in LSTMNetwork.cpp
 
   public :
 
diff --git a/torch_modules/include/MLP.hpp b/torch_modules/include/MLP.hpp
deleted file mode 100644
index 90bde50aea779f7d6d3188d98f7058053c0ec8e2..0000000000000000000000000000000000000000
--- a/torch_modules/include/MLP.hpp
+++ /dev/null
@@ -1,14 +0,0 @@
-#ifndef MLP__H
-#define MLP__H
-
-#include <torch/torch.h>
-
-class MLPImpl : torch::nn::Module
-{
-  public :
-
-  MLPImpl(const std::string & topology);
-};
-TORCH_MODULE(MLP);
-
-#endif
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index ffc0052cfd063e27693691c2e028f506fbf04049..be25c873978d61edc57ee014695e2110b8cb189b 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -15,31 +15,10 @@ class NeuralNetworkImpl : public torch::nn::Module
 
   static constexpr int maxNbEmbeddings = 150000;
 
-  std::vector<std::string> columns{"FORM"};
-  std::vector<int> bufferContext{-3,-2,-1,0,1};
-  std::vector<int> stackContext{};
-  std::vector<int> bufferFocused{};
-  std::vector<int> stackFocused{};
-
-  protected :
-
-  void setBufferContext(const std::vector<int> & bufferContext);
-  void setStackContext(const std::vector<int> & stackContext);
-  void setBufferFocused(const std::vector<int> & bufferFocused);
-  void setStackFocused(const std::vector<int> & stackFocused);
-
   public :
 
   virtual torch::Tensor forward(torch::Tensor input) = 0;
-  virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const;
-  std::vector<long> extractContextIndexes(const Config & config) const;
-  std::vector<long> extractFocusedIndexes(const Config & config) const;
-  int getContextSize() const;
-  void setColumns(const std::vector<std::string> & columns);
-  void addAppliableSplitTransitions(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const;
-  void addRawInput(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, int leftWindowRawInput, int rightWindowRawInput) const;
-  void addContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & contextIndexes, int unknownValueThreshold, const std::vector<std::string> & unknownValueColumns) const;
-  void addFocused(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & focusedIndexes, const std::vector<std::string> & focusedColumns, const std::vector<int> & maxNbElements) const;
+  virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0; // now pure virtual: each concrete network assembles its own feature context
 };
 TORCH_MODULE(NeuralNetwork);
 
diff --git a/torch_modules/include/RLTNetwork.hpp b/torch_modules/include/RLTNetwork.hpp
deleted file mode 100644
index 71b8aa55dce4740c1f08fd8ecb343826dedd8989..0000000000000000000000000000000000000000
--- a/torch_modules/include/RLTNetwork.hpp
+++ /dev/null
@@ -1,31 +0,0 @@
-#ifndef RLTNETWORK__H
-#define RLTNETWORK__H
-
-#include "NeuralNetwork.hpp"
-
-class RLTNetworkImpl : public NeuralNetworkImpl
-{
-  private :
-
-  static constexpr long maxNbChilds{8};
-  static inline std::vector<long> focusedBufferIndexes{0,1,2};
-  static inline std::vector<long> focusedStackIndexes{0,1};
-
-  int leftBorder, rightBorder;
-
-  torch::nn::Embedding wordEmbeddings{nullptr};
-  torch::nn::Linear linear1{nullptr};
-  torch::nn::Linear linear2{nullptr};
-  torch::nn::LSTM vectorBiLSTM{nullptr};
-  torch::nn::LSTM treeLSTM{nullptr};
-  torch::Tensor S;
-  torch::Tensor nullTree;
-
-  public :
-
-  RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
-  torch::Tensor forward(torch::Tensor input) override;
-  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
-};
-
-#endif
diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp
index e715cc4b920ed5aefe3258244f3ea0d8d888f938..8f58d7b30859a9cd4130fc506f83fc9c51bce34e 100644
--- a/torch_modules/include/RandomNetwork.hpp
+++ b/torch_modules/include/RandomNetwork.hpp
@@ -13,6 +13,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl
 
   RandomNetworkImpl(long outputSize);
   torch::Tensor forward(torch::Tensor input) override;
+  std::vector<std::vector<long>> extractContext(Config &, Dict &) const override; // required now that the base method is pure virtual
 };
 
 #endif
diff --git a/torch_modules/include/RawInputLSTM.hpp b/torch_modules/include/RawInputLSTM.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..db17d6f0014e615474a462635844e3d5251f3fb0
--- /dev/null
+++ b/torch_modules/include/RawInputLSTM.hpp
@@ -0,0 +1,26 @@
+#ifndef RAWINPUTLSTM__H
+#define RAWINPUTLSTM__H
+
+#include <torch/torch.h>
+#include "Submodule.hpp"
+#include "LSTM.hpp"
+
+class RawInputLSTMImpl : public torch::nn::Module, public Submodule // LSTM over a raw character window around the current position
+{
+  private :
+
+  LSTM lstm{nullptr};
+  int leftWindow, rightWindow; // character window bounds — NOTE(review): negative values presumably disable raw input, as in the removed CNNNetwork; confirm
+
+  public :
+
+  RawInputLSTMImpl(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options);
+  torch::Tensor forward(torch::Tensor input);
+  std::size_t getOutputSize() override;
+  std::size_t getInputSize() override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+};
+TORCH_MODULE(RawInputLSTM);
+
+#endif
+
diff --git a/torch_modules/include/SplitTransLSTM.hpp b/torch_modules/include/SplitTransLSTM.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..f90c0edcfedd699d7d9d18626db8d80cd40d385e
--- /dev/null
+++ b/torch_modules/include/SplitTransLSTM.hpp
@@ -0,0 +1,26 @@
+#ifndef SPLITTRANSLSTM__H
+#define SPLITTRANSLSTM__H
+
+#include <torch/torch.h>
+#include "Submodule.hpp"
+#include "LSTM.hpp"
+
+class SplitTransLSTMImpl : public torch::nn::Module, public Submodule // LSTM over the currently appliable split transitions
+{
+  private :
+
+  LSTM lstm{nullptr};
+  int maxNbTrans; // TODO confirm: fixed number of split-transition slots fed to the LSTM
+
+  public :
+
+  SplitTransLSTMImpl(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options);
+  torch::Tensor forward(torch::Tensor input);
+  std::size_t getOutputSize() override;
+  std::size_t getInputSize() override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+};
+TORCH_MODULE(SplitTransLSTM);
+
+#endif
+
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..437bbfa4e82ca29fb35b33924ba1fc3c16cbb126
--- /dev/null
+++ b/torch_modules/include/Submodule.hpp
@@ -0,0 +1,25 @@
+#ifndef SUBMODULE__H
+#define SUBMODULE__H
+
+#include "Dict.hpp"
+#include "Config.hpp"
+
+// Interface shared by the network submodules (ContextLSTM, RawInputLSTM, SplitTransLSTM, FocusedColumnLSTM):
+// each one owns a contiguous slice of the shared network input, starting at firstInputIndex.
+class Submodule
+{
+  protected :
+
+  std::size_t firstInputIndex{0}; // offset of this submodule's slice inside the shared input
+
+  public :
+
+  virtual ~Submodule() = default; // a base class with virtual functions needs a virtual destructor
+  void setFirstInputIndex(std::size_t firstInputIndex);
+  virtual std::size_t getOutputSize() = 0;
+  virtual std::size_t getInputSize() = 0;
+  virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const = 0;
+};
+
+#endif
+
diff --git a/torch_modules/src/CNN.cpp b/torch_modules/src/CNN.cpp
index cc67d6067d9d60eabdadf8ead1412cef0ece0115..dbc3797b12f7ce6e3e16e8d93e0959e0a3bdb804 100644
--- a/torch_modules/src/CNN.cpp
+++ b/torch_modules/src/CNN.cpp
@@ -26,7 +26,7 @@ torch::Tensor CNNImpl::forward(torch::Tensor input)
   return cnnOut;
 }
 
-int CNNImpl::getOutputSize()
+std::size_t CNNImpl::getOutputSize() // return type widened from int to match the unsigned size it computes
 {
   return windowSizes.size()*nbFilters;
 }
diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp
deleted file mode 100644
index c03b347ee55f71d3e130be22f4770d3000904847..0000000000000000000000000000000000000000
--- a/torch_modules/src/CNNNetwork.cpp
+++ /dev/null
@@ -1,187 +0,0 @@
-#include "CNNNetwork.hpp"
-
-CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
-{
-  constexpr int embeddingsSize = 64;
-  constexpr int hiddenSize = 1024;
-  constexpr int nbFiltersContext = 512;
-  constexpr int nbFiltersFocused = 64;
-
-  setBufferContext(bufferContext);
-  setStackContext(stackContext);
-  setColumns(columns);
-  setBufferFocused(focusedBufferIndexes);
-  setStackFocused(focusedStackIndexes);
-
-  rawInputSize =  leftWindowRawInput + rightWindowRawInput + 1;
-  if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
-    rawInputSize = 0;
-  else
-    rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize));
-  int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize();
-
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
-  embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
-  cnnDropout = register_module("cnn_dropout", torch::nn::Dropout(0.3));
-  hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
-  contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize));
-  int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize;
-  for (auto & col : focusedColumns)
-  {
-    std::vector<int> windows{2,3,4};
-    cnns.emplace_back(register_module(fmt::format("CNN_{}", col), CNN(windows, nbFiltersFocused, embeddingsSize)));
-    totalCnnOutputSize += cnns.back()->getOutputSize() * (focusedBufferIndexes.size()+focusedStackIndexes.size());
-  }
-  linear1 = register_module("linear1", torch::nn::Linear(totalCnnOutputSize, hiddenSize));
-  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
-}
-
-torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
-{
-  if (input.dim() == 1)
-    input = input.unsqueeze(0);
-
-  auto embeddings = embeddingsDropout(wordEmbeddings(input));
-
-  auto context = embeddings.narrow(1, rawInputSize, getContextSize());
-  context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});
-
-  auto elementsEmbeddings = embeddings.narrow(1, rawInputSize+context.size(1), input.size(1)-(rawInputSize+context.size(1)));
-
-  std::vector<torch::Tensor> cnnOutputs;
-
-  if (rawInputSize != 0)
-  {
-    auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1);
-    cnnOutputs.emplace_back(rawInputCNN(rawLetters.unsqueeze(1)));
-  }
-
-  auto curIndex = 0;
-  for (unsigned int i = 0; i < focusedColumns.size(); i++)
-  {
-    long nbElements = maxNbElements[i];
-    for (unsigned int focused = 0; focused < bufferFocused.size()+stackFocused.size(); focused++)
-    {
-      auto cnnInput = elementsEmbeddings.narrow(1, curIndex, nbElements).unsqueeze(1);
-      curIndex += nbElements;
-      cnnOutputs.emplace_back(cnns[i](cnnInput));
-    }
-  }
-
-  cnnOutputs.emplace_back(contextCNN(context.unsqueeze(1)));
-
-  auto totalInput = cnnDropout(torch::cat(cnnOutputs, 1));
-
-  return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
-}
-
-std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
-{
-  if (dict.size() >= maxNbEmbeddings)
-    util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
-
-  std::vector<long> contextIndexes = extractContextIndexes(config);
-  std::vector<std::vector<long>> context;
-  context.emplace_back();
-
-  if (rawInputSize > 0)
-  {
-    for (int i = 0; i < leftWindowRawInput; i++)
-      if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
-        context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
-      else
-        context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
-
-    for (int i = 0; i <= rightWindowRawInput; i++)
-      if (config.hasCharacter(config.getCharacterIndex()+i))
-        context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
-      else
-        context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
-  }
-
-  for (auto index : contextIndexes)
-    for (auto & col : columns)
-      if (index == -1)
-        for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
-      else
-      {
-        int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
-
-        for (auto & contextElement : context)
-          contextElement.push_back(dictIndex);
-
-        if (is_training())
-          if (col == "FORM" || col == "LEMMA")
-            if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
-            {
-              context.emplace_back(context.back());
-              context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
-            }
-      }
-
-  std::vector<long> focusedIndexes = extractFocusedIndexes(config);
-
-  for (auto & contextElement : context)
-    for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
-    {
-      auto & col = focusedColumns[colIndex];
-
-      for (auto index : focusedIndexes)
-      {
-        if (index == -1)
-        {
-          for (int i = 0; i < maxNbElements[colIndex]; i++)
-            contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
-          continue;
-        }
-
-        std::vector<std::string> elements;
-        if (col == "FORM")
-        {
-          auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());
-
-          for (int i = 0; i < maxNbElements[colIndex]; i++)
-            if (i < (int)asUtf8.size())
-              elements.emplace_back(fmt::format("{}", asUtf8[i]));
-            else
-              elements.emplace_back(Dict::nullValueStr);
-        }
-        else if (col == "FEATS")
-        {
-          auto splited = util::split(config.getAsFeature(col, index).get(), '|');
-
-          for (int i = 0; i < maxNbElements[colIndex]; i++)
-            if (i < (int)splited.size())
-              elements.emplace_back(fmt::format("FEATS({})", splited[i]));
-            else
-              elements.emplace_back(Dict::nullValueStr);
-        }
-        else if (col == "ID")
-        {
-          if (config.isTokenPredicted(index))
-            elements.emplace_back("ID(TOKEN)");
-          else if (config.isMultiwordPredicted(index))
-            elements.emplace_back("ID(MULTIWORD)");
-          else if (config.isEmptyNodePredicted(index))
-            elements.emplace_back("ID(EMPTYNODE)");
-        }
-        else
-        {
-          elements.emplace_back(config.getAsFeature(col, index));
-        }
-
-        if ((int)elements.size() != maxNbElements[colIndex])
-          util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));
-
-        for (auto & element : elements)
-          contextElement.emplace_back(dict.getIndexOrInsert(element));
-      }
-    }
-
-  if (!is_training() && context.size() > 1)
-    util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));
-
-  return context;
-}
-
diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp
deleted file mode 100644
index 2331d59a61e1f8ee4bd7679bc009125c61e70775..0000000000000000000000000000000000000000
--- a/torch_modules/src/ConcatWordsNetwork.cpp
+++ /dev/null
@@ -1,25 +0,0 @@
-#include "ConcatWordsNetwork.hpp"
-
-ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, const std::vector<int> & bufferContext, const std::vector<int> & stackContext)
-{
-  constexpr int embeddingsSize = 64;
-  constexpr int hiddenSize = 500;
-
-  setBufferContext(bufferContext);
-  setStackContext(stackContext);
-  setColumns({"FORM", "UPOS"});
-
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
-  linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, hiddenSize));
-  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
-  dropout = register_module("dropout", torch::nn::Dropout(0.3));
-}
-
-torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
-{
-  if (input.dim() == 1)
-    input = input.unsqueeze(0);
-  auto wordsAsEmb = dropout(wordEmbeddings(input).view({input.size(0), -1}));
-  return linear2(torch::relu(linear1(wordsAsEmb)));
-}
-
diff --git a/torch_modules/src/ContextLSTM.cpp b/torch_modules/src/ContextLSTM.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..95daa696df15c70797332f0d938c5646111e418d
--- /dev/null
+++ b/torch_modules/src/ContextLSTM.cpp
@@ -0,0 +1,62 @@
+#include "ContextLSTM.hpp"
+
+ContextLSTMImpl::ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold) : columns(columns), bufferContext(bufferContext), stackContext(stackContext), unknownValueThreshold(unknownValueThreshold)
+{
+  lstm = register_module("lstm", LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options)); // each step consumes the concatenation of one word's column embeddings
+}
+
+std::size_t ContextLSTMImpl::getOutputSize()
+{
+  return lstm->getOutputSize(bufferContext.size()+stackContext.size()); // sequence length = number of context words
+}
+
+std::size_t ContextLSTMImpl::getInputSize()
+{
+  return columns.size()*(bufferContext.size()+stackContext.size()); // one input cell per (word, column) pair
+}
+
+void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+{
+  std::vector<long> contextIndexes; // word indexes of the context; -1 marks a missing element
+
+  for (int index : bufferContext)
+    contextIndexes.emplace_back(config.getRelativeWordIndex(index));
+
+  for (int index : stackContext)
+    if (config.hasStack(index))
+      contextIndexes.emplace_back(config.getStack(index));
+    else
+      contextIndexes.emplace_back(-1);
+
+  for (auto index : contextIndexes)
+    for (auto & col : columns)
+      if (index == -1)
+        for (auto & contextElement : context)
+          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); // missing word -> null token, for every variant
+      else
+      {
+        int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
+
+        for (auto & contextElement : context)
+          contextElement.push_back(dictIndex);
+
+        if (is_training())
+          for (auto & targetCol : unknownValueColumns)
+            if (col == targetCol)
+              if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
+              { // rare value: fork an extra training variant with it replaced by the unknown token
+                context.emplace_back(context.back());
+                context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
+              }
+      }
+}
+
+torch::Tensor ContextLSTMImpl::forward(torch::Tensor input)
+{
+  auto context = input.narrow(1, firstInputIndex, getInputSize()); // take this submodule's slice of the shared input
+
+  context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*context.size(2)}); // regroup per word: one row per context word, its column embeddings concatenated
+
+  return lstm(context);
+}
+
diff --git a/torch_modules/src/DepthLayerTreeEmbedding.cpp b/torch_modules/src/DepthLayerTreeEmbedding.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d53a04abce60d08ff775a820500409ee860b76d9
--- /dev/null
+++ b/torch_modules/src/DepthLayerTreeEmbedding.cpp
@@ -0,0 +1,17 @@
+#include "DepthLayerTreeEmbedding.hpp"
+
// Stub module, presumably meant to embed a dependency (sub)tree layer by
// layer, keeping up to maxElemPerDepth nodes per depth level down to
// maxDepth — TODO confirm once implemented.
DepthLayerTreeEmbeddingImpl::DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth) : maxDepth(maxDepth), maxElemPerDepth(maxElemPerDepth)
{

}
+
+torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input)
+{
+
+}
+
+int DepthLayerTreeEmbeddingImpl::getOutputSize()
+{
+
+}
+
diff --git a/torch_modules/src/FocusedColumnLSTM.cpp b/torch_modules/src/FocusedColumnLSTM.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..9b5f52fba36b9af59ef4634c2bdc3728babc621d
--- /dev/null
+++ b/torch_modules/src/FocusedColumnLSTM.cpp
@@ -0,0 +1,94 @@
+#include "FocusedColumnLSTM.hpp"
+
// Encodes, for a single column, the first maxNbElements sub-elements
// (characters for FORM, traits for FEATS, ...) of each focused word taken
// from the buffer and the stack, through one shared LSTM.
FocusedColumnLSTMImpl::FocusedColumnLSTMImpl(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : focusedBuffer(focusedBuffer), focusedStack(focusedStack), column(column), maxNbElements(maxNbElements)
{
  lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
}
+
+torch::Tensor FocusedColumnLSTMImpl::forward(torch::Tensor input)
+{
+  std::vector<torch::Tensor> outputs;
+  for (unsigned int i = 0; i < focusedBuffer.size()+focusedStack.size(); i++)
+    outputs.emplace_back(lstm(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements)));
+
+  return torch::cat(outputs, 1);
+}
+
+std::size_t FocusedColumnLSTMImpl::getOutputSize()
+{
+  return (focusedBuffer.size()+focusedStack.size())*lstm->getOutputSize(maxNbElements);
+}
+
+std::size_t FocusedColumnLSTMImpl::getInputSize()
+{
+  return (focusedBuffer.size()+focusedStack.size()) * maxNbElements;
+}
+
// Appends, to every context variant, exactly maxNbElements dict indexes per
// focused word, describing the word's value in this->column:
//   FORM  -> its first maxNbElements UTF-8 characters, null-padded;
//   FEATS -> its first maxNbElements '|'-separated traits, null-padded;
//   ID    -> a single tag for the predicted token kind;
//   other -> the raw column value.
void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
{
  std::vector<long> focusedIndexes;

  // Words at the given offsets relative to the current buffer head.
  for (int index : focusedBuffer)
    focusedIndexes.emplace_back(config.getRelativeWordIndex(index));

  // Words on the stack; -1 marks an absent stack element.
  for (int index : focusedStack)
    if (config.hasStack(index))
      focusedIndexes.emplace_back(config.getStack(index));
    else
      focusedIndexes.emplace_back(-1);

  for (auto & contextElement : context)
  {
    for (auto index : focusedIndexes)
    {
      // Absent word: pad the whole slot with null values.
      if (index == -1)
      {
        for (int i = 0; i < maxNbElements; i++)
          contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
        continue;
      }

      std::vector<std::string> elements;
      if (column == "FORM")
      {
        // Character-level decomposition of the word form.
        auto asUtf8 = util::splitAsUtf8(config.getAsFeature(column, index).get());

        for (int i = 0; i < maxNbElements; i++)
          if (i < (int)asUtf8.size())
            elements.emplace_back(fmt::format("{}", asUtf8[i]));
          else
            elements.emplace_back(Dict::nullValueStr);
      }
      else if (column == "FEATS")
      {
        // Morphological traits, one per '|'-separated component.
        auto splited = util::split(config.getAsFeature(column, index).get(), '|');

        for (int i = 0; i < maxNbElements; i++)
          if (i < (int)splited.size())
            elements.emplace_back(fmt::format("FEATS({})", splited[i]));
          else
            elements.emplace_back(Dict::nullValueStr);
      }
      else if (column == "ID")
      {
        // NOTE(review): this branch emits at most one element, so
        // maxNbElements != 1 for the ID column (or a word matching none of
        // these predicates) trips the size check below — confirm intended.
        if (config.isTokenPredicted(index))
          elements.emplace_back("ID(TOKEN)");
        else if (config.isMultiwordPredicted(index))
          elements.emplace_back("ID(MULTIWORD)");
        else if (config.isEmptyNodePredicted(index))
          elements.emplace_back("ID(EMPTYNODE)");
      }
      else
      {
        // Generic column: the raw feature value (expects maxNbElements == 1).
        elements.emplace_back(config.getAsFeature(column, index));
      }

      if ((int)elements.size() != maxNbElements)
        util::myThrow(fmt::format("elements.size ({}) != maxNbElements ({})", elements.size(), maxNbElements));

      for (auto & element : elements)
        contextElement.emplace_back(dict.getIndexOrInsert(element));
    }
  }
}
+
diff --git a/torch_modules/src/LSTM.cpp b/torch_modules/src/LSTM.cpp
index 58b102a290b76330f5618e264d11579df39f5a2e..b8f8e7f7dfb2d568f69ccb4062764d85a7b930a8 100644
--- a/torch_modules/src/LSTM.cpp
+++ b/torch_modules/src/LSTM.cpp
@@ -1,6 +1,6 @@
 #include "LSTM.hpp"
 
-LSTMImpl::LSTMImpl(int inputSize, int outputSize, std::tuple<bool,bool,int,float,bool> options) : outputAll(std::get<4>(options))
+LSTMImpl::LSTMImpl(int inputSize, int outputSize, LSTMOptions options) : outputAll(std::get<4>(options))
 {
   auto lstmOptions = torch::nn::LSTMOptions(inputSize, outputSize)
     .batch_first(std::get<0>(options))
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
index 268b504591207e98a319ee34729cfea4d207cfec..430b2267ce1ae3d4dba987751749e1b22a527c76 100644
--- a/torch_modules/src/LSTMNetwork.cpp
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -1,6 +1,6 @@
 #include "LSTMNetwork.hpp"
 
-LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
+LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput)
 {
   constexpr int embeddingsSize = 256;
   constexpr int hiddenSize = 8192;
@@ -8,41 +8,45 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
   constexpr int focusedLSTMSize = 256;
   constexpr int rawInputLSTMSize = 32;
 
-  std::tuple<bool,bool,int,float,bool> lstmOptions{true,true,2,0.3,false};
+  LSTMImpl::LSTMOptions lstmOptions{true,true,2,0.3,false};
   auto lstmOptionsAll = lstmOptions;
   std::get<4>(lstmOptionsAll) = true;
 
-  setBufferContext(bufferContext);
-  setStackContext(stackContext);
-  setColumns(columns);
-  setBufferFocused(focusedBufferIndexes);
-  setStackFocused(focusedStackIndexes);
-
-  rawInputSize =  leftWindowRawInput + rightWindowRawInput + 1;
-  int rawInputLSTMOutSize = 0;
-  if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
-    rawInputSize = 0;
-  else
+  int currentOutputSize = embeddingsSize;
+  int currentInputSize = 1;
+
+  contextLSTM = register_module("contextLSTM", ContextLSTM(columns, embeddingsSize, contextLSTMSize, bufferContext, stackContext, lstmOptions, unknownValueThreshold));
+  contextLSTM->setFirstInputIndex(currentInputSize);
+  currentOutputSize += contextLSTM->getOutputSize();
+  currentInputSize += contextLSTM->getInputSize();
+
+  if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
   {
-    rawInputLSTM = register_module("rawInputLSTM", LSTM(embeddingsSize, rawInputLSTMSize, lstmOptionsAll));
-    rawInputLSTMOutSize = rawInputLSTM->getOutputSize(rawInputSize);
+    hasRawInputLSTM = true;
+    rawInputLSTM = register_module("rawInputLSTM", RawInputLSTM(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, lstmOptionsAll));
+    rawInputLSTM->setFirstInputIndex(currentInputSize);
+    currentOutputSize += rawInputLSTM->getOutputSize();
+    currentInputSize += rawInputLSTM->getInputSize();
   }
 
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
-  embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
-  hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
-  contextLSTM = register_module("contextLSTM", LSTM(columns.size()*embeddingsSize, contextLSTMSize, lstmOptions));
-  splitTransLSTM = register_module("splitTransLSTM", LSTM(embeddingsSize, embeddingsSize, lstmOptionsAll));
-
-  int totalLSTMOutputSize = rawInputLSTMOutSize + contextLSTM->getOutputSize(getContextSize()) + splitTransLSTM->getOutputSize(Config::maxNbAppliableSplitTransitions);
+  splitTransLSTM = register_module("splitTransLSTM", SplitTransLSTM(Config::maxNbAppliableSplitTransitions, embeddingsSize, embeddingsSize, lstmOptionsAll));
+  splitTransLSTM->setFirstInputIndex(currentInputSize);
+  currentOutputSize += splitTransLSTM->getOutputSize();
+  currentInputSize += splitTransLSTM->getInputSize();
 
   for (unsigned int i = 0; i < focusedColumns.size(); i++)
   {
-    lstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), LSTM(embeddingsSize, focusedLSTMSize, lstmOptions)));
-    totalLSTMOutputSize += (bufferFocused.size()+stackFocused.size())*lstms.back()->getOutputSize(maxNbElements[i]);
+    focusedLstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), FocusedColumnLSTM(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedLSTMSize, lstmOptions)));
+    focusedLstms.back()->setFirstInputIndex(currentInputSize);
+    currentOutputSize += focusedLstms.back()->getOutputSize();
+    currentInputSize += focusedLstms.back()->getInputSize();
   }
 
-  linear1 = register_module("linear1", torch::nn::Linear(embeddingsSize+totalLSTMOutputSize, hiddenSize));
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
+  embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
+  hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
+
+  linear1 = register_module("linear1", torch::nn::Linear(currentOutputSize, hiddenSize));
   linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
 }
 
@@ -53,40 +57,19 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
 
   auto embeddings = embeddingsDropout(wordEmbeddings(input));
 
-  auto state = embeddings.narrow(1, 0, 1).squeeze(1);
-
-  auto splitTrans = embeddings.narrow(1, 1, Config::maxNbAppliableSplitTransitions);
+  std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)};
 
-  auto context = embeddings.narrow(1, 1+splitTrans.size(1)+rawInputSize, getContextSize());
+  outputs.emplace_back(contextLSTM(embeddings));
 
-  context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});
+  if (hasRawInputLSTM)
+    outputs.emplace_back(rawInputLSTM(embeddings));
 
-  auto elementsEmbeddings = embeddings.narrow(1, 1+splitTrans.size(1)+rawInputSize+context.size(1), input.size(1)-(1+splitTrans.size(1)+rawInputSize+context.size(1)));
+  outputs.emplace_back(splitTransLSTM(embeddings));
 
-  std::vector<torch::Tensor> lstmOutputs;
+  for (auto & lstm : focusedLstms)
+    outputs.emplace_back(lstm(embeddings));
 
-  lstmOutputs.emplace_back(state);
-
-  if (rawInputSize != 0)
-  {
-    auto rawLetters = embeddings.narrow(1, splitTrans.size(1), rawInputSize);
-    lstmOutputs.emplace_back(rawInputLSTM(rawLetters));
-  }
-
-  lstmOutputs.emplace_back(splitTransLSTM(splitTrans));
-
-  auto curIndex = 0;
-  for (unsigned int i = 0; i < focusedColumns.size(); i++)
-    for (unsigned int focused = 0; focused < bufferFocused.size()+stackFocused.size(); focused++)
-    {
-      auto lstmInput = elementsEmbeddings.narrow(1, curIndex, maxNbElements[i]);
-      curIndex += maxNbElements[i];
-      lstmOutputs.emplace_back(lstms[i](lstmInput));
-    }
-
-  lstmOutputs.emplace_back(contextLSTM(context));
-
-  auto totalInput = torch::cat(lstmOutputs, 1);
+  auto totalInput = torch::cat(outputs, 1);
 
   return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
 }
@@ -101,13 +84,12 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
 
   context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
 
-  addAppliableSplitTransitions(context, dict, config);
-
-  addRawInput(context, dict, config, leftWindowRawInput, rightWindowRawInput);
-
-  addContext(context, dict, config, extractContextIndexes(config), unknownValueThreshold, {"FORM","LEMMA"});
-
-  addFocused(context, dict, config, extractFocusedIndexes(config), focusedColumns, maxNbElements);
+  contextLSTM->addToContext(context, dict, config);
+  if (hasRawInputLSTM)
+    rawInputLSTM->addToContext(context, dict, config);
+  splitTransLSTM->addToContext(context, dict, config);
+  for (auto & lstm : focusedLstms)
+    lstm->addToContext(context, dict, config);
 
   if (!is_training() && context.size() > 1)
     util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));
diff --git a/torch_modules/src/MLP.cpp b/torch_modules/src/MLP.cpp
deleted file mode 100644
index 182e880a6df2c60b131bc71af463e47d901b64b8..0000000000000000000000000000000000000000
--- a/torch_modules/src/MLP.cpp
+++ /dev/null
@@ -1,8 +0,0 @@
-#include "MLP.hpp"
-#include <regex>
-
-MLPImpl::MLPImpl(const std::string & topology)
-{
-  
-}
-
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index 4a123b261922227d352fd26fdd1f4ff7e49c93e4..02e8a191bfb4b2bc718b6e815a266bec252fb24b 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -1,195 +1,4 @@
 #include "NeuralNetwork.hpp"
-#include "Transition.hpp"
 
 torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
 
-std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config) const
-{
-  std::vector<long> context;
-
-  for (int index : bufferContext)
-    context.emplace_back(config.getRelativeWordIndex(index));
-
-  for (int index : stackContext)
-    if (config.hasStack(index))
-      context.emplace_back(config.getStack(index));
-    else
-      context.emplace_back(-1);
-
-  return context;
-}
-
-std::vector<long> NeuralNetworkImpl::extractFocusedIndexes(const Config & config) const
-{
-  std::vector<long> context;
-
-  for (int index : bufferFocused)
-    context.emplace_back(config.getRelativeWordIndex(index));
-
-  for (int index : stackFocused)
-    if (config.hasStack(index))
-      context.emplace_back(config.getStack(index));
-    else
-      context.emplace_back(-1);
-
-  return context;
-}
-
-std::vector<std::vector<long>> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
-{
-  std::vector<long> indexes = extractContextIndexes(config);
-  std::vector<long> context;
-
-  for (auto & col : columns)
-    for (auto index : indexes)
-      if (index == -1)
-        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
-      else
-        context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index)));
-
-  return {context};
-}
-
-int NeuralNetworkImpl::getContextSize() const
-{
-  return columns.size()*(bufferContext.size()+stackContext.size());
-}
-
-void NeuralNetworkImpl::setBufferContext(const std::vector<int> & bufferContext)
-{
-  this->bufferContext = bufferContext;
-}
-
-void NeuralNetworkImpl::setStackContext(const std::vector<int> & stackContext)
-{
-  this->stackContext = stackContext;
-}
-
-void NeuralNetworkImpl::setBufferFocused(const std::vector<int> & bufferFocused)
-{
-  this->bufferFocused = bufferFocused;
-}
-
-void NeuralNetworkImpl::setStackFocused(const std::vector<int> & stackFocused)
-{
-  this->stackFocused = stackFocused;
-}
-
-void NeuralNetworkImpl::setColumns(const std::vector<std::string> & columns)
-{
-  this->columns = columns;
-}
-
-void NeuralNetworkImpl::addAppliableSplitTransitions(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
-{
-  auto & splitTransitions = config.getAppliableSplitTransitions();
-  for (int i = 0; i < Config::maxNbAppliableSplitTransitions; i++)
-    if (i < (int)splitTransitions.size())
-      context.back().emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
-    else
-      context.back().emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
-}
-
-void NeuralNetworkImpl::addRawInput(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, int leftWindowRawInput, int rightWindowRawInput) const
-{
-  if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
-    return;
-
-  for (int i = 0; i < leftWindowRawInput; i++)
-    if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
-      context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
-    else
-      context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
-
-  for (int i = 0; i <= rightWindowRawInput; i++)
-    if (config.hasCharacter(config.getCharacterIndex()+i))
-      context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
-    else
-      context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
-}
-
-void NeuralNetworkImpl::addContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & contextIndexes, int unknownValueThreshold, const std::vector<std::string> & unknownValueColumns) const
-{
-  for (auto index : contextIndexes)
-    for (auto & col : columns)
-      if (index == -1)
-        for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
-      else
-      {
-        int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
-
-        for (auto & contextElement : context)
-          contextElement.push_back(dictIndex);
-
-        if (is_training())
-          for (auto & targetCol : unknownValueColumns)
-            if (col == targetCol)
-              if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
-              {
-                context.emplace_back(context.back());
-                context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
-              }
-      }
-}
-
-void NeuralNetworkImpl::addFocused(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & focusedIndexes, const std::vector<std::string> & focusedColumns, const std::vector<int> & maxNbElements) const
-{
-  for (auto & contextElement : context)
-    for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
-    {
-      auto & col = focusedColumns[colIndex];
-
-      for (auto index : focusedIndexes)
-      {
-        if (index == -1)
-        {
-          for (int i = 0; i < maxNbElements[colIndex]; i++)
-            contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
-          continue;
-        }
-
-        std::vector<std::string> elements;
-        if (col == "FORM")
-        {
-          auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());
-
-          for (int i = 0; i < maxNbElements[colIndex]; i++)
-            if (i < (int)asUtf8.size())
-              elements.emplace_back(fmt::format("{}", asUtf8[i]));
-            else
-              elements.emplace_back(Dict::nullValueStr);
-        }
-        else if (col == "FEATS")
-        {
-          auto splited = util::split(config.getAsFeature(col, index).get(), '|');
-
-          for (int i = 0; i < maxNbElements[colIndex]; i++)
-            if (i < (int)splited.size())
-              elements.emplace_back(fmt::format("FEATS({})", splited[i]));
-            else
-              elements.emplace_back(Dict::nullValueStr);
-        }
-        else if (col == "ID")
-        {
-          if (config.isTokenPredicted(index))
-            elements.emplace_back("ID(TOKEN)");
-          else if (config.isMultiwordPredicted(index))
-            elements.emplace_back("ID(MULTIWORD)");
-          else if (config.isEmptyNodePredicted(index))
-            elements.emplace_back("ID(EMPTYNODE)");
-        }
-        else
-        {
-          elements.emplace_back(config.getAsFeature(col, index));
-        }
-
-        if ((int)elements.size() != maxNbElements[colIndex])
-          util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));
-
-        for (auto & element : elements)
-          contextElement.emplace_back(dict.getIndexOrInsert(element));
-      }
-    }
-}
-
diff --git a/torch_modules/src/RLTNetwork.cpp b/torch_modules/src/RLTNetwork.cpp
deleted file mode 100644
index e4f3fc215aef10c1cfbaa05371b094e980adf6cc..0000000000000000000000000000000000000000
--- a/torch_modules/src/RLTNetwork.cpp
+++ /dev/null
@@ -1,190 +0,0 @@
-#include "RLTNetwork.hpp"
-
-RLTNetworkImpl::RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
-{
-  constexpr int embeddingsSize = 30;
-  constexpr int lstmOutputSize = 128;
-  constexpr int treeEmbeddingsSize = 256;
-  constexpr int hiddenSize = 500;
-
-  //TODO gerer ces context
-  this->leftBorder = leftBorder;
-  this->rightBorder = rightBorder;
-  setBufferContext({});
-  setStackContext({});
-  setColumns({"FORM", "UPOS"});
-
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
-  linear1 = register_module("linear1", torch::nn::Linear(treeEmbeddingsSize*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
-  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
-  vectorBiLSTM = register_module("vector_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize*columns.size(), lstmOutputSize).batch_first(true).bidirectional(true)));
-  treeLSTM = register_module("tree_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(treeEmbeddingsSize+2*lstmOutputSize, treeEmbeddingsSize).batch_first(true).bidirectional(false)));
-  S = register_parameter("S", torch::randn(treeEmbeddingsSize));
-  nullTree = register_parameter("null_tree", torch::randn(treeEmbeddingsSize));
-}
-
-torch::Tensor RLTNetworkImpl::forward(torch::Tensor input)
-{
-  if (input.dim() == 1)
-    input = input.unsqueeze(0);
-
-  auto focusedIndexes = input.narrow(1, 0, focusedBufferIndexes.size()+focusedStackIndexes.size());
-  auto computeOrder = input.narrow(1, focusedIndexes.size(1), getContextSize()/columns.size());
-  auto childsFlat = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1), maxNbChilds*(getContextSize()/columns.size()));
-  auto childs = torch::reshape(childsFlat, {childsFlat.size(0), computeOrder.size(1), maxNbChilds});
-  auto wordIndexes = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1)+childsFlat.size(1), getContextSize());
-  auto baseEmbeddings = wordEmbeddings(wordIndexes);
-  auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {baseEmbeddings.size(0), (int)baseEmbeddings.size(1)/(int)columns.size(), (int)baseEmbeddings.size(2)*(int)columns.size()});
-  auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output;
-
-  std::vector<std::map<int, torch::Tensor>> treeRepresentations;
-  for (unsigned int batch = 0; batch < computeOrder.size(0); batch++)
-  {
-    treeRepresentations.emplace_back();
-    for (unsigned int i = 0; i < computeOrder[batch].size(0); i++)
-    {
-      int index = computeOrder[batch][i].item<int>();
-      if (index == -1)
-        break;
-      std::vector<torch::Tensor> inputVector;
-      inputVector.emplace_back(torch::cat({vectorRepresentations[batch][index], S}, 0));
-      for (unsigned int childIndex = 0; childIndex < maxNbChilds; childIndex++)
-      {
-        int child = childs[batch][index][childIndex].item<int>();
-        if (child == -1)
-          break;
-        inputVector.emplace_back(torch::cat({vectorRepresentations[batch][index], treeRepresentations[batch].count(child) ? treeRepresentations[batch][child] : nullTree}, 0));
-      }
-      auto lstmInput = torch::stack(inputVector, 0).unsqueeze(0);
-      auto lstmOut = treeLSTM(lstmInput).output.permute({1,0,2})[-1].squeeze();
-      treeRepresentations[batch][index] = lstmOut;
-    }
-  }
-
-  std::vector<torch::Tensor> focusedTrees;
-  std::vector<torch::Tensor> representations;
-  for (unsigned int batch = 0; batch < focusedIndexes.size(0); batch++)
-  {
-    focusedTrees.clear();
-    for (unsigned int i = 0; i < focusedIndexes[batch].size(0); i++)
-    {
-      int index = focusedIndexes[batch][i].item<int>();
-      if (index == -1)
-        focusedTrees.emplace_back(nullTree);
-      else
-        focusedTrees.emplace_back(treeRepresentations[batch].count(index) ? treeRepresentations[batch][index] : nullTree);
-    }
-    representations.emplace_back(torch::cat(focusedTrees, 0).unsqueeze(0));
-  }
-
-  auto representation = torch::cat(representations, 0);
-  return linear2(torch::relu(linear1(representation)));
-}
-
-std::vector<std::vector<long>> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const
-{
-  std::vector<long> contextIndexes;
-  std::stack<int> leftContext;
-  for (int index = config.getWordIndex()-1; config.has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
-    if (config.isToken(index))
-      leftContext.push(index);
-
-  while ((int)contextIndexes.size() < leftBorder-(int)leftContext.size())
-    contextIndexes.emplace_back(-1);
-  while (!leftContext.empty())
-  {
-    contextIndexes.emplace_back(leftContext.top());
-    leftContext.pop();
-  }
-
-  for (int index = config.getWordIndex(); config.has(0,index,0) && (int)contextIndexes.size() < leftBorder+rightBorder+1; ++index)
-    if (config.isToken(index))
-      contextIndexes.emplace_back(index);
-
-  while ((int)contextIndexes.size() < leftBorder+rightBorder+1)
-    contextIndexes.emplace_back(-1);
-
-  std::map<long, long> indexInContext;
-  for (auto & l : contextIndexes)
-    indexInContext.emplace(std::make_pair(l, indexInContext.size()));
-
-  std::vector<long> headOf;
-  for (auto & l : contextIndexes)
-  {
-    if (l == -1)
-      headOf.push_back(-1);
-    else
-    {
-      auto & head = config.getAsFeature(Config::headColName, l);
-      if (util::isEmpty(head) or head == "_")
-        headOf.push_back(-1);
-      else if  (indexInContext.count(std::stoi(head)))
-        headOf.push_back(std::stoi(head));
-      else
-        headOf.push_back(-1);
-    }
-  }
-
-  std::vector<std::vector<long>> childs(headOf.size());
-  for (unsigned int i = 0; i < headOf.size(); i++)
-    if (headOf[i] != -1)
-      childs[indexInContext[headOf[i]]].push_back(contextIndexes[i]);
-
-  std::vector<long> treeComputationOrder;
-  std::vector<bool> treeIsComputed(contextIndexes.size(), false);
-
-  std::function<void(long)> depthFirst;
-  depthFirst = [&config, &depthFirst, &indexInContext, &treeComputationOrder, &treeIsComputed, &childs](long root)
-  {
-    if (!indexInContext.count(root))
-      return;
-
-    if (treeIsComputed[indexInContext[root]])
-      return;
-
-    for (auto child : childs[indexInContext[root]])
-      depthFirst(child);
-
-    treeIsComputed[indexInContext[root]] = true;
-    treeComputationOrder.push_back(indexInContext[root]);
-  };
-
-  for (auto & l : focusedBufferIndexes)
-    if (contextIndexes[leftBorder+l] != -1)
-      depthFirst(contextIndexes[leftBorder+l]);
-
-  for (auto & l : focusedStackIndexes)
-    if (config.hasStack(l))
-      depthFirst(config.getStack(l));
-
-  std::vector<long> context;
-  
-  for (auto & c : focusedBufferIndexes)
-    context.push_back(leftBorder+c);
-  for (auto & c : focusedStackIndexes)
-    if (config.hasStack(c) && indexInContext.count(config.getStack(c)))
-      context.push_back(indexInContext[config.getStack(c)]);
-    else
-      context.push_back(-1);
-  for (auto & c : treeComputationOrder)
-    context.push_back(c);
-  while (context.size() < contextIndexes.size()+focusedBufferIndexes.size()+focusedStackIndexes.size())
-    context.push_back(-1);
-  for (auto & c : childs)
-  {
-    for (unsigned int i = 0; i < maxNbChilds; i++)
-      if (i < c.size())
-        context.push_back(indexInContext[c[i]]);
-      else
-        context.push_back(-1);
-  }
-  for (auto & l : contextIndexes)
-    for (auto & col : columns)
-      if (l == -1)
-        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
-      else
-        context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, l)));
-
-  return {context};
-}
-
diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp
index 8e973e4c1af83ba3d2b998956da3172f74e9b572..9730dd295a1194936f5f925ae4126b689e9956d3 100644
--- a/torch_modules/src/RandomNetwork.cpp
+++ b/torch_modules/src/RandomNetwork.cpp
@@ -2,11 +2,6 @@
 
 RandomNetworkImpl::RandomNetworkImpl(long outputSize) : outputSize(outputSize)
 {
-  setBufferContext({0});
-  setStackContext({});
-  setBufferFocused({});
-  setStackFocused({});
-  setColumns({"FORM"});
 }
 
 torch::Tensor RandomNetworkImpl::forward(torch::Tensor input)
@@ -17,3 +12,8 @@ torch::Tensor RandomNetworkImpl::forward(torch::Tensor input)
   return torch::randn({input.size(0), outputSize}, torch::TensorOptions().device(device).requires_grad(true));
 }
 
+std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &, Dict &) const
+{
+  return std::vector<std::vector<long>>();
+}
+
diff --git a/torch_modules/src/RawInputLSTM.cpp b/torch_modules/src/RawInputLSTM.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ebcfbfd9e6a6077eb280ff3d7e84c4fc7129fc8d
--- /dev/null
+++ b/torch_modules/src/RawInputLSTM.cpp
@@ -0,0 +1,40 @@
+#include "RawInputLSTM.hpp"
+
// Encodes the raw characters around the current character index
// (window [-leftWindow, +rightWindow]) through an LSTM.
RawInputLSTMImpl::RawInputLSTMImpl(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : leftWindow(leftWindow), rightWindow(rightWindow)
{
  lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
}
+
+torch::Tensor RawInputLSTMImpl::forward(torch::Tensor input)
+{
+  return lstm(input.narrow(1, firstInputIndex, getInputSize()));
+}
+
+std::size_t RawInputLSTMImpl::getOutputSize()
+{
+  return lstm->getOutputSize(leftWindow + rightWindow + 1);
+}
+
+std::size_t RawInputLSTMImpl::getInputSize()
+{
+  return leftWindow + rightWindow + 1;
+}
+
+void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+{
+  if (leftWindow < 0 or rightWindow < 0)
+    return;
+
+  for (int i = 0; i < leftWindow; i++)
+    if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i))
+      context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i))));
+    else
+      context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+
+  for (int i = 0; i <= rightWindow; i++)
+    if (config.hasCharacter(config.getCharacterIndex()+i))
+      context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
+    else
+      context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+}
+
diff --git a/torch_modules/src/SplitTransLSTM.cpp b/torch_modules/src/SplitTransLSTM.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..283358c8483ee83952704524046e05a787e21217
--- /dev/null
+++ b/torch_modules/src/SplitTransLSTM.cpp
@@ -0,0 +1,33 @@
+#include "SplitTransLSTM.hpp"
+#include "Transition.hpp"
+
// Encodes the list of currently appliable split transitions (null-padded to
// maxNbTrans entries) through an LSTM.
SplitTransLSTMImpl::SplitTransLSTMImpl(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : maxNbTrans(maxNbTrans)
{
  lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
}
+
+torch::Tensor SplitTransLSTMImpl::forward(torch::Tensor input)
+{
+  return lstm(input.narrow(1, firstInputIndex, getInputSize()));
+}
+
+std::size_t SplitTransLSTMImpl::getOutputSize()
+{
+  return lstm->getOutputSize(maxNbTrans);
+}
+
+std::size_t SplitTransLSTMImpl::getInputSize()
+{
+  return maxNbTrans;
+}
+
+void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+{
+  auto & splitTransitions = config.getAppliableSplitTransitions();
+  for (int i = 0; i < maxNbTrans; i++)
+    if (i < (int)splitTransitions.size())
+      context.back().emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
+    else
+      context.back().emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+}
+
diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2af75a3ee4413f9ee884b9a8e11370941754ddfa
--- /dev/null
+++ b/torch_modules/src/Submodule.cpp
@@ -0,0 +1,7 @@
+#include "Submodule.hpp"
+
// Records the offset of this submodule's first feature slot within the
// flat per-sample input row, so forward() can narrow() its own slice.
void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
{
  this->firstInputIndex = firstInputIndex;
}
+