diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 4951e5dce72e57b6d02f74598e426c248ce0ddcb..0684d56fcad4bf1e2866103378bc4e9878354667 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -19,6 +19,8 @@ class Classifier
 
   void initNeuralNetwork(const std::vector<std::string> & definition);
   void initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
+  void initGRU(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
+  void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
 
   public :
 
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index d9ead9a52e3dbf7e0beebd6ff08bc515712862a9..113ee9bda8932ccb86e95e8521eb5292dc5d429f 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -1,7 +1,9 @@
 #include "Classifier.hpp"
 #include "util.hpp"
 #include "LSTMNetwork.hpp"
+#include "GRUNetwork.hpp"
 #include "RandomNetwork.hpp"
+#include "ModularNetwork.hpp"
 
 Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition)
 {
@@ -87,6 +89,10 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
     this->nn.reset(new RandomNetworkImpl(nbOutputsPerState));
   else if (networkType == "LSTM")
     initLSTM(definition, curIndex, nbOutputsPerState);
+  else if (networkType == "GRU")
+    initGRU(definition, curIndex, nbOutputsPerState);
+  else if (networkType == "Modular")
+    initModular(definition, curIndex, nbOutputsPerState);
   else
-    util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM'", networkType));
+    util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM, GRU, Modular'", networkType));
 
@@ -349,3 +355,241 @@ void Classifier::setState(const std::string & state)
   nn->setState(state);
 }
 
+void Classifier::initGRU(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
+{
+  int unknownValueThreshold;
+  std::vector<int> bufferContext, stackContext;
+  std::vector<std::string> columns, focusedColumns, treeEmbeddingColumns;
+  std::vector<int> focusedBuffer, focusedStack;
+  std::vector<int> treeEmbeddingBuffer, treeEmbeddingStack;
+  std::vector<int> maxNbElements;
+  std::vector<int> treeEmbeddingNbElems;
+  std::vector<std::pair<int, float>> mlp;
+  int rawInputLeftWindow, rawInputRightWindow;
+  int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, treeEmbeddingSize;
+  bool bilstm, drop2d;
+  float lstmDropout, embeddingsDropout, totalInputDropout;
+
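+  // Each hyperparameter is read from its own definition line, in the fixed
+  // order below; the textual labels (e.g. "Buffer context :") are optional,
+  // since every regex also accepts a bare value.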
definition[curIndex] : "", "(Unknown value threshold :) unknownValueThreshold")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Buffer context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&bufferContext](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + bufferContext.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Buffer context :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Stack context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&stackContext](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + stackContext.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Stack context :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&columns](auto sm) + { + columns = util::split(sm.str(1), ' '); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Columns :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedBuffer](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + focusedBuffer.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused buffer :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedStack](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + focusedStack.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused stack :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedColumns](auto sm) + { + focusedColumns = util::split(sm.str(1), ' '); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused columns :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Max nb elements :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&maxNbElements](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + maxNbElements.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(Max nb elements :) {size1 size2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input left window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLeftWindow](auto sm) + { + rawInputLeftWindow = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input left window :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input right window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputRightWindow](auto sm) + { + rawInputRightWindow = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input right window :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsSize](auto sm) + { + embeddingsSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings size :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:MLP :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&mlp](auto sm) + { + auto params = util::split(sm.str(1), ' '); + if (params.size() % 2) + util::myThrow("MLP must have even number of parameters"); + for (unsigned int i = 0; i < params.size()/2; i++) + mlp.emplace_back(std::make_pair(std::stoi(params[2*i]), std::stof(params[2*i+1]))); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(MLP :) {hidden1 dropout1 hidden2 dropout2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Context LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&contextLSTMSize](auto sm) + { + contextLSTMSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Context LSTM size :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&focusedLSTMSize](auto sm) + { + focusedLSTMSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused LSTM size :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Rawinput LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLSTMSize](auto sm) + { + rawInputLSTMSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(Raw LSTM size :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Split trans LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&splitTransLSTMSize](auto sm) + { + splitTransLSTMSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Split trans LSTM size :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Num layers :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&nbLayers](auto sm) + { + nbLayers = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Num layers :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:BiLSTM :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&bilstm](auto sm) + { + bilstm = sm.str(1) == "true"; + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(BiLSTM :) true|false")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LSTM dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&lstmDropout](auto sm) + { + lstmDropout = std::stof(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(LSTM dropout :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Total input dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&totalInputDropout](auto sm) + { + totalInputDropout = std::stof(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Total input dropout :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsDropout](auto sm) + { + embeddingsDropout = std::stof(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings dropout :) value")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Dropout 2d :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&drop2d](auto sm) + { + drop2d = sm.str(1) == "true"; + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Dropout 2d :) true|false")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingColumns](auto sm) + { + treeEmbeddingColumns = util::split(sm.str(1), ' '); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(Tree embedding columns :) {column1 column2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingBuffer](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + treeEmbeddingBuffer.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding buffer :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingStack](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + treeEmbeddingStack.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding stack :) {index1 index2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding nb :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingNbElems](auto sm) + { + for (auto & index : util::split(sm.str(1), ' ')) + treeEmbeddingNbElems.emplace_back(std::stoi(index)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding nb :) {size1 size2...}")); + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&treeEmbeddingSize](auto sm) + { + treeEmbeddingSize = std::stoi(sm.str(1)); + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? 
definition[curIndex] : "", "(Tree embedding size :) value")); + + this->nn.reset(new GRUNetworkImpl(nbOutputsPerState, unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d)); +} + +void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState) +{ + std::string anyBlanks = "(?:(?:\\s|\\t)*)"; + std::regex endRegex(fmt::format("{}End{}",anyBlanks,anyBlanks)); + std::vector<std::string> modulesDefinitions; + + for (; curIndex < definition.size(); curIndex++) + { + if (util::doIfNameMatch(endRegex,definition[curIndex],[](auto sm){})) + { + curIndex++; + break; + } + modulesDefinitions.emplace_back(definition[curIndex]); + } + + this->nn.reset(new ModularNetworkImpl(nbOutputsPerState, modulesDefinitions)); +} + diff --git a/torch_modules/include/ContextLSTM.hpp b/torch_modules/include/ContextModule.hpp similarity index 59% rename from torch_modules/include/ContextLSTM.hpp rename to torch_modules/include/ContextModule.hpp index 3e3bbacac0e56cfd38e981279a0f6a54c1f41b3d..a9b609034023857a245c985990cc37f1d06f2bb7 100644 --- a/torch_modules/include/ContextLSTM.hpp +++ b/torch_modules/include/ContextModule.hpp @@ -1,15 +1,18 @@ -#ifndef CONTEXTLSTM__H -#define CONTEXTLSTM__H +#ifndef CONTEXTMODULE__H +#define CONTEXTMODULE__H #include <torch/torch.h> #include "Submodule.hpp" +#include "MyModule.hpp" +#include "GRU.hpp" #include "LSTM.hpp" -class ContextLSTMImpl : public torch::nn::Module, public Submodule +class ContextModuleImpl : public Submodule { private : - LSTM lstm{nullptr}; + torch::nn::Embedding wordEmbeddings{nullptr}; + std::shared_ptr<MyModule> myModule{nullptr}; std::vector<std::string> columns; std::vector<int> bufferContext; std::vector<int> stackContext; @@ -18,13 +21,13 @@ class ContextLSTMImpl : public torch::nn::Module, public Submodule public : - ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold); + ContextModuleImpl(const std::string & definition); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; }; -TORCH_MODULE(ContextLSTM); +TORCH_MODULE(ContextModule); #endif diff --git a/torch_modules/include/DepthLayerTreeEmbedding.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp similarity index 59% rename from torch_modules/include/DepthLayerTreeEmbedding.hpp rename to torch_modules/include/DepthLayerTreeEmbeddingModule.hpp index 2a8f7e8ca0ccd8fea1313e4b4437700c9bdd6bef..cd6c33df504b5e977ae1cece09f6bf92dec60e7e 100644 --- a/torch_modules/include/DepthLayerTreeEmbedding.hpp +++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp @@ -3,9 +3,11 @@ #include <torch/torch.h> #include "Submodule.hpp" +#include "MyModule.hpp" #include "LSTM.hpp" +#include "GRU.hpp" -class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule +class 
+  for (; curIndex < definition.size(); curIndex++)
+  {
+    if (util::doIfNameMatch(endRegex,definition[curIndex],[](auto sm){}))
+    {
+      curIndex++;
+      break;
+    }
+    modulesDefinitions.emplace_back(definition[curIndex]);
+  }
+
+  this->nn.reset(new ModularNetworkImpl(nbOutputsPerState, modulesDefinitions));
+}
+
diff --git a/torch_modules/include/ContextLSTM.hpp b/torch_modules/include/ContextModule.hpp
similarity index 59%
rename from torch_modules/include/ContextLSTM.hpp
rename to torch_modules/include/ContextModule.hpp
index 3e3bbacac0e56cfd38e981279a0f6a54c1f41b3d..a9b609034023857a245c985990cc37f1d06f2bb7 100644
--- a/torch_modules/include/ContextLSTM.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -1,15 +1,18 @@
-#ifndef CONTEXTLSTM__H
-#define CONTEXTLSTM__H
+#ifndef CONTEXTMODULE__H
+#define CONTEXTMODULE__H
 
 #include <torch/torch.h>
 #include "Submodule.hpp"
+#include "MyModule.hpp"
+#include "GRU.hpp"
 #include "LSTM.hpp"
 
-class ContextLSTMImpl : public torch::nn::Module, public Submodule
+class ContextModuleImpl : public Submodule
 {
   private :
 
-  LSTM lstm{nullptr};
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<std::string> columns;
   std::vector<int> bufferContext;
   std::vector<int> stackContext;
@@ -18,13 +21,13 @@
 
   public :
 
-  ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold);
+  ContextModuleImpl(const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
-TORCH_MODULE(ContextLSTM);
+TORCH_MODULE(ContextModule);
 
 #endif
diff --git a/torch_modules/include/DepthLayerTreeEmbedding.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
similarity index 59%
rename from torch_modules/include/DepthLayerTreeEmbedding.hpp
rename to torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
index 2a8f7e8ca0ccd8fea1313e4b4437700c9bdd6bef..cd6c33df504b5e977ae1cece09f6bf92dec60e7e 100644
--- a/torch_modules/include/DepthLayerTreeEmbedding.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
@@ -3,9 +3,11 @@
 
 #include <torch/torch.h>
 #include "Submodule.hpp"
+#include "MyModule.hpp"
 #include "LSTM.hpp"
+#include "GRU.hpp"
 
-class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule
+class DepthLayerTreeEmbeddingModule : public Submodule
 {
   private :
 
@@ -13,17 +15,16 @@ class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule
   std::vector<std::string> columns;
   std::vector<int> focusedBuffer;
   std::vector<int> focusedStack;
-  std::vector<LSTM> depthLstm;
+  std::vector<std::shared_ptr<MyModule>> depthModules;
 
   public :
 
-  DepthLayerTreeEmbeddingImpl(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, LSTMImpl::LSTMOptions options);
+  DepthLayerTreeEmbeddingModule(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, MyModule::ModuleOptions options);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
-TORCH_MODULE(DepthLayerTreeEmbedding);
 
 #endif
diff --git a/torch_modules/include/FocusedColumnLSTM.hpp b/torch_modules/include/FocusedColumnModule.hpp
similarity index 54%
rename from torch_modules/include/FocusedColumnLSTM.hpp
rename to torch_modules/include/FocusedColumnModule.hpp
index fd5d915df6d42d24294e6a75dd42c87d6e81dec1..9c2732aa82d0b23490a3d36ef527c26aa55a1822 100644
--- a/torch_modules/include/FocusedColumnLSTM.hpp
+++ b/torch_modules/include/FocusedColumnModule.hpp
@@ -1,28 +1,29 @@
-#ifndef FOCUSEDCOLUMNLSTM__H
-#define FOCUSEDCOLUMNLSTM__H
+#ifndef FOCUSEDCOLUMNMODULE__H
+#define FOCUSEDCOLUMNMODULE__H
 
 #include <torch/torch.h>
 #include "Submodule.hpp"
+#include "MyModule.hpp"
 #include "LSTM.hpp"
+#include "GRU.hpp"
 
-class FocusedColumnLSTMImpl : public torch::nn::Module, public Submodule
+class FocusedColumnModule : public Submodule
 {
   private :
 
-  LSTM lstm{nullptr};
+  std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<int> focusedBuffer, focusedStack;
   std::string column;
   int maxNbElements;
 
   public :
 
-  FocusedColumnLSTMImpl(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options);
+  FocusedColumnModule(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
-TORCH_MODULE(FocusedColumnLSTM);
 
 #endif
diff --git a/torch_modules/include/GRU.hpp b/torch_modules/include/GRU.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..7980db43d5a8a808b2af96ffe59c12ec37fa539e
--- /dev/null
+++ b/torch_modules/include/GRU.hpp
@@ -0,0 +1,23 @@
+#ifndef GRU__H
+#define GRU__H
+
+#include <torch/torch.h>
+#include "MyModule.hpp"
+
+class GRUImpl : public MyModule
+{
+  private :
+
+  torch::nn::GRU gru{nullptr};
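+  // When true, forward() returns the flattened outputs of every timestep;
+  // otherwise only the sequence extremities are kept.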
+  bool outputAll;
+
+  public :
+
+  GRUImpl(int inputSize, int outputSize, ModuleOptions options);
+  torch::Tensor forward(torch::Tensor input);
+  int getOutputSize(int sequenceLength);
+};
+TORCH_MODULE(GRU);
+
+#endif
diff --git a/torch_modules/include/GRUNetwork.hpp b/torch_modules/include/GRUNetwork.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..ecff8a05c7604c45b04fbc9f2caf6c0637dc7558
--- /dev/null
+++ b/torch_modules/include/GRUNetwork.hpp
@@ -0,0 +1,36 @@
+#ifndef GRUNETWORK__H
+#define GRUNETWORK__H
+
+#include "NeuralNetwork.hpp"
+#include "ContextModule.hpp"
+#include "RawInputModule.hpp"
+#include "SplitTransModule.hpp"
+#include "FocusedColumnModule.hpp"
+#include "DepthLayerTreeEmbeddingModule.hpp"
+#include "MLP.hpp"
+
+class GRUNetworkImpl : public NeuralNetworkImpl
+{
+//  private :
+//
+//  torch::nn::Embedding wordEmbeddings{nullptr};
+//  torch::nn::Dropout2d embeddingsDropout2d{nullptr};
+//  torch::nn::Dropout embeddingsDropout{nullptr};
+//  torch::nn::Dropout inputDropout{nullptr};
+//
+//  MLP mlp{nullptr};
+//  ContextModule contextGRU{nullptr};
+//  RawInputModule rawInputGRU{nullptr};
+//  SplitTransModule splitTransGRU{nullptr};
+//  DepthLayerTreeEmbeddingModule treeEmbedding{nullptr};
+//  std::vector<FocusedColumnModule> focusedLstms;
+//  std::map<std::string,torch::nn::Linear> outputLayersPerState;
+
+  public :
+
+  GRUNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextGRUSize, int focusedGRUSize, int rawInputGRUSize, int splitTransGRUSize, int numLayers, bool bigru, float gruDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d);
+  torch::Tensor forward(torch::Tensor input) override;
+  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
+};
+
+#endif
diff --git a/torch_modules/include/LSTM.hpp b/torch_modules/include/LSTM.hpp
index c45cb9fff41aa6f9fc0883bcf7e37c9db5e3a8e9..7e4ed0c2e408cbb6cbed16b33082aec48bb69105 100644
--- a/torch_modules/include/LSTM.hpp
+++ b/torch_modules/include/LSTM.hpp
@@ -2,14 +2,10 @@
 #define LSTM__H
 
 #include <torch/torch.h>
-#include "fmt/core.h"
+#include "MyModule.hpp"
 
-class LSTMImpl : public torch::nn::Module
+class LSTMImpl : public MyModule
 {
-  public :
-
-  using LSTMOptions = std::tuple<bool,bool,int,float,bool>;
-
   private :
 
   torch::nn::LSTM lstm{nullptr};
@@ -17,7 +13,7 @@ class LSTMImpl : public torch::nn::Module
 
   public :
 
-  LSTMImpl(int inputSize, int outputSize, LSTMOptions options);
+  LSTMImpl(int inputSize, int outputSize, ModuleOptions options);
   torch::Tensor forward(torch::Tensor input);
   int getOutputSize(int sequenceLength);
 };
diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp
index 76b9303a8cb7245fe27e42d9b7f9673a832ef2e3..d83acf5d0ada988972ca43e749644ad16fc49bae 100644
--- a/torch_modules/include/LSTMNetwork.hpp
+++ b/torch_modules/include/LSTMNetwork.hpp
@@ -2,29 +2,29 @@
 #define LSTMNETWORK__H
 
 #include "NeuralNetwork.hpp"
-#include "ContextLSTM.hpp"
-#include "RawInputLSTM.hpp"
-#include "SplitTransLSTM.hpp"
-#include "FocusedColumnLSTM.hpp"
+#include "ContextModule.hpp"
+#include "RawInputModule.hpp"
+#include "SplitTransModule.hpp"
+#include "FocusedColumnModule.hpp"
+#include "DepthLayerTreeEmbeddingModule.hpp"
 #include "MLP.hpp"
-#include "DepthLayerTreeEmbedding.hpp"
 
 class LSTMNetworkImpl : public NeuralNetworkImpl
 {
-  private :
-
-  torch::nn::Embedding wordEmbeddings{nullptr};
-  torch::nn::Dropout2d embeddingsDropout2d{nullptr};
-  torch::nn::Dropout embeddingsDropout{nullptr};
-  torch::nn::Dropout inputDropout{nullptr};
-
-  MLP mlp{nullptr};
-  ContextLSTM contextLSTM{nullptr};
-  RawInputLSTM rawInputLSTM{nullptr};
-  SplitTransLSTM splitTransLSTM{nullptr};
-  DepthLayerTreeEmbedding treeEmbedding{nullptr};
-  std::vector<FocusedColumnLSTM> focusedLstms;
-  std::map<std::string,torch::nn::Linear> outputLayersPerState;
+//  private :
+//
+//  torch::nn::Embedding wordEmbeddings{nullptr};
+//  torch::nn::Dropout2d embeddingsDropout2d{nullptr};
+//  torch::nn::Dropout embeddingsDropout{nullptr};
+//  torch::nn::Dropout inputDropout{nullptr};
+//
+//  MLP mlp{nullptr};
+//  ContextModule contextLSTM{nullptr};
+//  RawInputModule rawInputLSTM{nullptr};
+//  SplitTransModule splitTransLSTM{nullptr};
+//  DepthLayerTreeEmbeddingModule treeEmbedding{nullptr};
+//  std::vector<FocusedColumnModule> focusedLstms;
+//  std::map<std::string,torch::nn::Linear> outputLayersPerState;
 
   public :
diff --git a/torch_modules/include/MLP.hpp b/torch_modules/include/MLP.hpp
index be272f1cd1369a7b1290aefd1868265111ac00da..bd108a461107e1fa4c0bb609c733f03515b3e785 100644
--- a/torch_modules/include/MLP.hpp
+++ b/torch_modules/include/MLP.hpp
@@ -13,7 +13,7 @@ class MLPImpl : public torch::nn::Module
 
   public :
 
-  MLPImpl(int inputSize, std::vector<std::pair<int, float>> params);
+  MLPImpl(int inputSize, std::string definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t outputSize() const;
 };
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..871c8bb3c1f00652a0b5a6849a412811c19d9094
--- /dev/null
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -0,0 +1,32 @@
+#ifndef MODULARNETWORK__H
+#define MODULARNETWORK__H
+
+#include "NeuralNetwork.hpp"
+#include "ContextModule.hpp"
+#include "RawInputModule.hpp"
+#include "SplitTransModule.hpp"
+#include "FocusedColumnModule.hpp"
+#include "DepthLayerTreeEmbeddingModule.hpp"
+#include "MLP.hpp"
+
+class ModularNetworkImpl : public NeuralNetworkImpl
+{
+  private :
+
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  torch::nn::Dropout2d embeddingsDropout2d{nullptr};
+  torch::nn::Dropout embeddingsDropout{nullptr};
+  torch::nn::Dropout inputDropout{nullptr};
+
+  MLP mlp{nullptr};
+  std::vector<std::shared_ptr<Submodule>> modules;
+  std::map<std::string,torch::nn::Linear> outputLayersPerState;
+
+  public :
+
+  ModularNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions);
+  torch::Tensor forward(torch::Tensor input) override;
+  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
+};
+
+#endif
diff --git a/torch_modules/include/MyModule.hpp b/torch_modules/include/MyModule.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..02018a3473b892121506181c24bdc5642827f3e6
--- /dev/null
+++ b/torch_modules/include/MyModule.hpp
@@ -0,0 +1,25 @@
+#ifndef MYMODULE__H
+#define MYMODULE__H
+
+#include <torch/torch.h>
+
+class MyModule : public torch::nn::Module
+{
+  public :
+
ModuleOptions & num_layers(int num) {std::get<2>(*this)=num; return *this;} + ModuleOptions & dropout(float val) {std::get<3>(*this)=val; return *this;} + ModuleOptions & complete(bool val) {std::get<4>(*this)=val; return *this;} + }; + + public : + + virtual int getOutputSize(int sequenceLength) = 0; + virtual torch::Tensor forward(torch::Tensor) = 0; +}; + +#endif diff --git a/torch_modules/include/RawInputLSTM.hpp b/torch_modules/include/RawInputModule.hpp similarity index 56% rename from torch_modules/include/RawInputLSTM.hpp rename to torch_modules/include/RawInputModule.hpp index 0e08560836b735f181849571ff0beec8f02bc335..0134f974d7640f77cc8334f2828c71863a939230 100644 --- a/torch_modules/include/RawInputLSTM.hpp +++ b/torch_modules/include/RawInputModule.hpp @@ -1,26 +1,27 @@ -#ifndef RAWINPUTLSTM__H -#define RAWINPUTLSTM__H +#ifndef RAWINPUTMODULE__H +#define RAWINPUTMODULE__H #include <torch/torch.h> #include "Submodule.hpp" +#include "MyModule.hpp" #include "LSTM.hpp" +#include "GRU.hpp" -class RawInputLSTMImpl : public torch::nn::Module, public Submodule +class RawInputModule : public Submodule { private : - LSTM lstm{nullptr}; + std::shared_ptr<MyModule> myModule{nullptr}; int leftWindow, rightWindow; public : - RawInputLSTMImpl(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options); + RawInputModule(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; }; -TORCH_MODULE(RawInputLSTM); #endif diff --git a/torch_modules/include/SplitTransLSTM.hpp b/torch_modules/include/SplitTransModule.hpp similarity index 56% rename from torch_modules/include/SplitTransLSTM.hpp rename to torch_modules/include/SplitTransModule.hpp index 85d542ce8510bd0c1d11b2ca6c1f280aeb386d55..8be8e633fab7f182242940eee1308b3cabe0d676 100644 --- a/torch_modules/include/SplitTransLSTM.hpp +++ b/torch_modules/include/SplitTransModule.hpp @@ -1,26 +1,27 @@ -#ifndef SPLITTRANSLSTM__H -#define SPLITTRANSLSTM__H +#ifndef SPLITTRANSMODULE__H +#define SPLITTRANSMODULE__H #include <torch/torch.h> #include "Submodule.hpp" +#include "MyModule.hpp" #include "LSTM.hpp" +#include "GRU.hpp" -class SplitTransLSTMImpl : public torch::nn::Module, public Submodule +class SplitTransModule : public Submodule { private : - LSTM lstm{nullptr}; + std::shared_ptr<MyModule> myModule{nullptr}; int maxNbTrans; public : - SplitTransLSTMImpl(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options); + SplitTransModule(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override; }; -TORCH_MODULE(SplitTransLSTM); #endif diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp index cc381013aea518aeefe8422b36537283d5d0da94..77c1a4feb08628615d1d163369f0a9272970d475 100644 --- a/torch_modules/include/Submodule.hpp +++ b/torch_modules/include/Submodule.hpp @@ -1,10 +1,11 @@ #ifndef SUBMODULE__H #define SUBMODULE__H +#include <torch/torch.h> #include "Dict.hpp" 
#include "Config.hpp" -class Submodule +class Submodule : public torch::nn::Module { protected : @@ -16,6 +17,7 @@ class Submodule virtual std::size_t getOutputSize() = 0; virtual std::size_t getInputSize() = 0; virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const = 0; + virtual torch::Tensor forward(torch::Tensor input) = 0; }; #endif diff --git a/torch_modules/src/ContextLSTM.cpp b/torch_modules/src/ContextLSTM.cpp deleted file mode 100644 index d24778878ec51303d55b3b0e7a3bc25f0fbdc9cc..0000000000000000000000000000000000000000 --- a/torch_modules/src/ContextLSTM.cpp +++ /dev/null @@ -1,60 +0,0 @@ -#include "ContextLSTM.hpp" - -ContextLSTMImpl::ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold) : columns(columns), bufferContext(bufferContext), stackContext(stackContext), unknownValueThreshold(unknownValueThreshold) -{ - lstm = register_module("lstm", LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options)); -} - -std::size_t ContextLSTMImpl::getOutputSize() -{ - return lstm->getOutputSize(bufferContext.size()+stackContext.size()); -} - -std::size_t ContextLSTMImpl::getInputSize() -{ - return columns.size()*(bufferContext.size()+stackContext.size()); -} - -void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const -{ - std::vector<long> contextIndexes; - - for (int index : bufferContext) - contextIndexes.emplace_back(config.getRelativeWordIndex(index)); - - for (int index : stackContext) - if (config.hasStack(index)) - contextIndexes.emplace_back(config.getStack(index)); - else - contextIndexes.emplace_back(-1); - - for (auto index : contextIndexes) - for (auto & col : columns) - if (index == -1) - { - for (auto & contextElement : context) - contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); - } - else - { - int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index)); - - for (auto & contextElement : context) - contextElement.push_back(dictIndex); - - for (auto & targetCol : unknownValueColumns) - if (col == targetCol) - if (dict.getNbOccs(dictIndex) <= unknownValueThreshold) - context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr); - } -} - -torch::Tensor ContextLSTMImpl::forward(torch::Tensor input) -{ - auto context = input.narrow(1, firstInputIndex, getInputSize()); - - context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*context.size(2)}); - - return lstm(context); -} - diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d2c1e6af087e2696960a6cdfefc03a9e3827f2b6 --- /dev/null +++ b/torch_modules/src/ContextModule.cpp @@ -0,0 +1,98 @@ +#include "ContextModule.hpp" + +ContextModuleImpl::ContextModuleImpl(const std::string & definition) +{ + std::regex regex("(?:(?:\\s|\\t)*)Unk\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); + if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) + { + try + { + unknownValueThreshold = std::stoi(sm.str(1)); + + for (auto & index : util::split(sm.str(2), ' ')) + 
+  std::regex regex("(?:(?:\\s|\\t)*)Unk\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+    {
+      try
+      {
+        unknownValueThreshold = std::stoi(sm.str(1));
+
+        for (auto & index : util::split(sm.str(2), ' '))
+          bufferContext.emplace_back(std::stoi(index));
+
+        for (auto & index : util::split(sm.str(3), ' '))
+          stackContext.emplace_back(std::stoi(index));
+
+        columns = util::split(sm.str(4), ' ');
+
+        auto subModuleType = sm.str(5);
+        auto subModuleArguments = util::split(sm.str(6), ' ');
+
+        auto options = MyModule::ModuleOptions(true)
+          .bidirectional(std::stoi(subModuleArguments[0]))
+          .num_layers(std::stoi(subModuleArguments[1]))
+          .dropout(std::stof(subModuleArguments[2]))
+          .complete(std::stoi(subModuleArguments[3]));
+
+        int inSize = std::stoi(sm.str(7));
+        int outSize = std::stoi(sm.str(8));
+
+        wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
+
+        if (subModuleType == "LSTM")
+          myModule = register_module("myModule", LSTM(columns.size()*inSize, outSize, options));
+        else if (subModuleType == "GRU")
+          myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
+        else
+          util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
+
+      } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+    }))
+    util::myThrow(fmt::format("invalid definition '{}'", definition));
+}
+
+std::size_t ContextModuleImpl::getOutputSize()
+{
+  return myModule->getOutputSize(bufferContext.size()+stackContext.size());
+}
+
+std::size_t ContextModuleImpl::getInputSize()
+{
+  return columns.size()*(bufferContext.size()+stackContext.size());
+}
+
+void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const
+{
+  std::vector<long> contextIndexes;
+
+  for (int index : bufferContext)
+    contextIndexes.emplace_back(config.getRelativeWordIndex(index));
+
+  for (int index : stackContext)
+    if (config.hasStack(index))
+      contextIndexes.emplace_back(config.getStack(index));
+    else
+      contextIndexes.emplace_back(-1);
+
+  for (auto index : contextIndexes)
+    for (auto & col : columns)
+      if (index == -1)
+      {
+        for (auto & contextElement : context)
+          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+      }
+      else
+      {
+        int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
+
+        for (auto & contextElement : context)
+          contextElement.push_back(dictIndex);
+
+        for (auto & targetCol : unknownValueColumns)
+          if (col == targetCol)
+            if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
+              context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
+      }
+}
+
+torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
+{
+  auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()));
+
+  context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*context.size(2)});
+
+  return myModule->forward(context);
+}
+
diff --git a/torch_modules/src/DepthLayerTreeEmbedding.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
similarity index 68%
rename from torch_modules/src/DepthLayerTreeEmbedding.cpp
rename to torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index b506f9219fd8284094960907975294fbd3a5b28a..f3d7bb2881e5a3d187aa2e95d426168bbabce976 100644
--- a/torch_modules/src/DepthLayerTreeEmbedding.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -1,13 +1,13 @@
-#include "DepthLayerTreeEmbedding.hpp"
+#include "DepthLayerTreeEmbeddingModule.hpp"
 
-DepthLayerTreeEmbeddingImpl::DepthLayerTreeEmbeddingImpl(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, LSTMImpl::LSTMOptions options) :
+DepthLayerTreeEmbeddingModule::DepthLayerTreeEmbeddingModule(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, MyModule::ModuleOptions options) :
   maxElemPerDepth(maxElemPerDepth), columns(columns), focusedBuffer(focusedBuffer), focusedStack(focusedStack)
 {
   for (unsigned int i = 0; i < maxElemPerDepth.size(); i++)
-    depthLstm.emplace_back(register_module(fmt::format("lstm_{}",i), LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options)));
+    depthModules.emplace_back(register_module(fmt::format("lstm_{}",i), LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options)));
 }
 
-torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input)
+torch::Tensor DepthLayerTreeEmbeddingModule::forward(torch::Tensor input)
 {
   auto context = input.narrow(1, firstInputIndex, getInputSize());
 
@@ -17,24 +17,24 @@ torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input)
   for (unsigned int focused = 0; focused < focusedBuffer.size()+focusedStack.size(); focused++)
     for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
     {
-      outputs.emplace_back(depthLstm[depth](context.narrow(1, offset, maxElemPerDepth[depth]*columns.size()).view({input.size(0), maxElemPerDepth[depth], (long)columns.size()*input.size(2)})));
+      outputs.emplace_back(depthModules[depth]->forward(context.narrow(1, offset, maxElemPerDepth[depth]*columns.size()).view({input.size(0), maxElemPerDepth[depth], (long)columns.size()*input.size(2)})));
       offset += maxElemPerDepth[depth]*columns.size();
     }
 
   return torch::cat(outputs, 1);
 }
 
-std::size_t DepthLayerTreeEmbeddingImpl::getOutputSize()
+std::size_t DepthLayerTreeEmbeddingModule::getOutputSize()
 {
   std::size_t outputSize = 0;
 
   for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
-    outputSize += depthLstm[depth]->getOutputSize(maxElemPerDepth[depth]);
+    outputSize += depthModules[depth]->getOutputSize(maxElemPerDepth[depth]);
 
   return outputSize*(focusedBuffer.size()+focusedStack.size());
 }
 
-std::size_t DepthLayerTreeEmbeddingImpl::getInputSize()
+std::size_t DepthLayerTreeEmbeddingModule::getInputSize()
 {
   int inputSize = 0;
   for (int maxElem : maxElemPerDepth)
@@ -42,7 +42,7 @@ std::size_t DepthLayerTreeEmbeddingImpl::getInputSize()
   return inputSize;
 }
 
-void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void DepthLayerTreeEmbeddingModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   std::vector<long> focusedIndexes;
 
diff --git a/torch_modules/src/FocusedColumnLSTM.cpp b/torch_modules/src/FocusedColumnModule.cpp
similarity index 72%
rename from torch_modules/src/FocusedColumnLSTM.cpp
rename to torch_modules/src/FocusedColumnModule.cpp
index e39af636c817fdc1677cfd9131b85ec7fb1bd3ba..717bc747c85be8521271d80499698af525741585 100644
--- a/torch_modules/src/FocusedColumnLSTM.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -1,30 +1,30 @@
-#include "FocusedColumnLSTM.hpp"
+#include "FocusedColumnModule.hpp"
 
-FocusedColumnLSTMImpl::FocusedColumnLSTMImpl(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : focusedBuffer(focusedBuffer), focusedStack(focusedStack), column(column), maxNbElements(maxNbElements)
+FocusedColumnModule::FocusedColumnModule(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options) : focusedBuffer(focusedBuffer), focusedStack(focusedStack), column(column), maxNbElements(maxNbElements)
 {
-  lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
+  myModule = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
 }
 
-torch::Tensor FocusedColumnLSTMImpl::forward(torch::Tensor input)
+torch::Tensor FocusedColumnModule::forward(torch::Tensor input)
 {
   std::vector<torch::Tensor> outputs;
 
   for (unsigned int i = 0; i < focusedBuffer.size()+focusedStack.size(); i++)
-    outputs.emplace_back(lstm(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements)));
+    outputs.emplace_back(myModule->forward(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements)));
 
   return torch::cat(outputs, 1);
 }
 
-std::size_t FocusedColumnLSTMImpl::getOutputSize()
+std::size_t FocusedColumnModule::getOutputSize()
 {
-  return (focusedBuffer.size()+focusedStack.size())*lstm->getOutputSize(maxNbElements);
+  return (focusedBuffer.size()+focusedStack.size())*myModule->getOutputSize(maxNbElements);
 }
 
-std::size_t FocusedColumnLSTMImpl::getInputSize()
+std::size_t FocusedColumnModule::getInputSize()
 {
   return (focusedBuffer.size()+focusedStack.size()) * maxNbElements;
 }
 
-void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void FocusedColumnModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   std::vector<long> focusedIndexes;
 
diff --git a/torch_modules/src/GRU.cpp b/torch_modules/src/GRU.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..fa6de5ccf0d594166f2544465dc5903a204692d4
--- /dev/null
+++ b/torch_modules/src/GRU.cpp
@@ -0,0 +1,34 @@
+#include "GRU.hpp"
+
+GRUImpl::GRUImpl(int inputSize, int outputSize, ModuleOptions options) : outputAll(std::get<4>(options))
+{
+  auto gruOptions = torch::nn::GRUOptions(inputSize, outputSize)
+    .batch_first(std::get<0>(options))
+    .bidirectional(std::get<1>(options))
+    .num_layers(std::get<2>(options))
+    .dropout(std::get<3>(options));
+
+  gru = register_module("gru", torch::nn::GRU(gruOptions));
+}
+
+torch::Tensor GRUImpl::forward(torch::Tensor input)
+{
+  auto gruOut = std::get<0>(gru(input));
+
+  if (outputAll)
+    return gruOut.reshape({gruOut.size(0), -1});
+
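+  // Only the sequence extremities are kept: for a bidirectional GRU the first
+  // and last timesteps (each of size 2*hidden) are concatenated, hence the
+  // factor 4 in getOutputSize().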
+  if (gru->options.bidirectional())
+    return torch::cat({gruOut.narrow(1,0,1).squeeze(1), gruOut.narrow(1,gruOut.size(1)-1,1).squeeze(1)}, 1);
+
+  return gruOut.narrow(1,gruOut.size(1)-1,1).squeeze(1);
+}
+
+int GRUImpl::getOutputSize(int sequenceLength)
+{
+  if (outputAll)
+    return sequenceLength * gru->options.hidden_size() * (gru->options.bidirectional() ? 2 : 1);
+
+  return gru->options.hidden_size() * (gru->options.bidirectional() ? 4 : 1);
+}
+
diff --git a/torch_modules/src/GRUNetwork.cpp b/torch_modules/src/GRUNetwork.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..dc7f36609d5a83f94486ee30dc3025bc30ca5bcc
--- /dev/null
+++ b/torch_modules/src/GRUNetwork.cpp
@@ -0,0 +1,120 @@
+#include "GRUNetwork.hpp"
+
+GRUNetworkImpl::GRUNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextGRUSize, int focusedGRUSize, int rawInputGRUSize, int splitTransGRUSize, int numLayers, bool bigru, float gruDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d)
+{
+//  MyModule::ModuleOptions gruOptions{true,bigru,numLayers,gruDropout,false};
+//  auto gruOptionsAll = gruOptions;
+//  std::get<4>(gruOptionsAll) = true;
+//
+//  int currentOutputSize = embeddingsSize;
+//  int currentInputSize = 1;
+//
+//  contextGRU = register_module("contextGRU", ContextModule(columns, embeddingsSize, contextGRUSize, bufferContext, stackContext, gruOptions, unknownValueThreshold));
+//  contextGRU->setFirstInputIndex(currentInputSize);
+//  currentOutputSize += contextGRU->getOutputSize();
+//  currentInputSize += contextGRU->getInputSize();
+//
+//  if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
+//  {
+//    rawInputGRU = register_module("rawInputGRU", RawInputModule(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputGRUSize, gruOptionsAll));
+//    rawInputGRU->setFirstInputIndex(currentInputSize);
+//    currentOutputSize += rawInputGRU->getOutputSize();
+//    currentInputSize += rawInputGRU->getInputSize();
+//  }
+//
+//  if (!treeEmbeddingColumns.empty())
+//  {
+//    treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbeddingModule(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,gruOptions));
+//    treeEmbedding->setFirstInputIndex(currentInputSize);
+//    currentOutputSize += treeEmbedding->getOutputSize();
+//    currentInputSize += treeEmbedding->getInputSize();
+//  }
+//
+//  splitTransGRU = register_module("splitTransGRU", SplitTransModule(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransGRUSize, gruOptionsAll));
+//  splitTransGRU->setFirstInputIndex(currentInputSize);
+//  currentOutputSize += splitTransGRU->getOutputSize();
+//  currentInputSize += splitTransGRU->getInputSize();
+//
+//  for (unsigned int i = 0; i < focusedColumns.size(); i++)
+//  {
+//    focusedLstms.emplace_back(register_module(fmt::format("GRU_{}", focusedColumns[i]), FocusedColumnModule(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedGRUSize, gruOptions)));
+//    focusedLstms.back()->setFirstInputIndex(currentInputSize);
+//    currentOutputSize += focusedLstms.back()->getOutputSize();
+//    currentInputSize += focusedLstms.back()->getInputSize();
+//  }
+//
+//  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
+//  if (drop2d)
+//    embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue));
+//  else
+//    embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue));
+//  inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout));
+//
+//  mlp = register_module("mlp", MLP(currentOutputSize, mlpParams));
+//
+//  for (auto & it : nbOutputsPerState)
+//    outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
+}
+
+torch::Tensor GRUNetworkImpl::forward(torch::Tensor input)
+{
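+  // Placeholder passthrough; the intended forward pass is kept commented out below.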
+  return input;
+//  if (input.dim() == 1)
+//    input = input.unsqueeze(0);
+//
+//  auto embeddings = wordEmbeddings(input);
+//  if (embeddingsDropout2d.is_empty())
+//    embeddings = embeddingsDropout(embeddings);
+//  else
+//    embeddings = embeddingsDropout2d(embeddings);
+//
+//  std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)};
+//
+//  outputs.emplace_back(contextGRU(embeddings));
+//
+//  if (!rawInputGRU.is_empty())
+//    outputs.emplace_back(rawInputGRU(embeddings));
+//
+//  if (!treeEmbedding.is_empty())
+//    outputs.emplace_back(treeEmbedding(embeddings));
+//
+//  outputs.emplace_back(splitTransGRU(embeddings));
+//
+//  for (auto & gru : focusedLstms)
+//    outputs.emplace_back(gru(embeddings));
+//
+//  auto totalInput = inputDropout(torch::cat(outputs, 1));
+//
+//  return outputLayersPerState.at(getState())(mlp(totalInput));
+}
+
+std::vector<std::vector<long>> GRUNetworkImpl::extractContext(Config & config, Dict & dict) const
+{
+  std::vector<std::vector<long>> context;
+  return context;
+//  if (dict.size() >= maxNbEmbeddings)
+//    util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
+//
+//  context.emplace_back();
+//
+//  context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
+//
+//  contextGRU->addToContext(context, dict, config, mustSplitUnknown());
+//
+//  if (!rawInputGRU.is_empty())
+//    rawInputGRU->addToContext(context, dict, config, mustSplitUnknown());
+//
+//  if (!treeEmbedding.is_empty())
+//    treeEmbedding->addToContext(context, dict, config, mustSplitUnknown());
+//
+//  splitTransGRU->addToContext(context, dict, config, mustSplitUnknown());
+//
+//  for (auto & gru : focusedLstms)
+//    gru->addToContext(context, dict, config, mustSplitUnknown());
+//
+//  if (!mustSplitUnknown() && context.size() > 1)
+//    util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size()));
+//
+//  return context;
+}
+
diff --git a/torch_modules/src/LSTM.cpp b/torch_modules/src/LSTM.cpp
index af89a3dedddc3750451f75442213eeb52482dfda..2844b17a256bac5de90017343fc4c7b2ad466e89 100644
--- a/torch_modules/src/LSTM.cpp
+++ b/torch_modules/src/LSTM.cpp
@@ -1,6 +1,6 @@
 #include "LSTM.hpp"
 
-LSTMImpl::LSTMImpl(int inputSize, int outputSize, LSTMOptions options) : outputAll(std::get<4>(options))
+LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options) : outputAll(std::get<4>(options))
 {
   auto lstmOptions = torch::nn::LSTMOptions(inputSize, outputSize)
     .batch_first(std::get<0>(options))
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
index 476e24dc938bd302db62194fae0e4cb665f82dab..e46d175aec9d5e83f1d365d40cdd0b7e5c3dc977 100644
--- a/torch_modules/src/LSTMNetwork.cpp
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -2,117 +2,119 @@
 
 LSTMNetworkImpl::LSTMNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d)
 {
-  LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false};
-  auto lstmOptionsAll = lstmOptions;
-  std::get<4>(lstmOptionsAll) = true;
-
-  int currentOutputSize = embeddingsSize;
-  int currentInputSize = 1;
-
-  contextLSTM = register_module("contextLSTM", ContextLSTM(columns, embeddingsSize, contextLSTMSize, bufferContext, stackContext, lstmOptions, unknownValueThreshold));
-  contextLSTM->setFirstInputIndex(currentInputSize);
-  currentOutputSize += contextLSTM->getOutputSize();
-  currentInputSize += contextLSTM->getInputSize();
-
-  if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
-  {
-    rawInputLSTM = register_module("rawInputLSTM", RawInputLSTM(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, lstmOptionsAll));
-    rawInputLSTM->setFirstInputIndex(currentInputSize);
-    currentOutputSize += rawInputLSTM->getOutputSize();
-    currentInputSize += rawInputLSTM->getInputSize();
-  }
-
-  if (!treeEmbeddingColumns.empty())
-  {
-    treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,lstmOptions));
-    treeEmbedding->setFirstInputIndex(currentInputSize);
-    currentOutputSize += treeEmbedding->getOutputSize();
-    currentInputSize += treeEmbedding->getInputSize();
-  }
-
-  splitTransLSTM = register_module("splitTransLSTM", SplitTransLSTM(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransLSTMSize, lstmOptionsAll));
-  splitTransLSTM->setFirstInputIndex(currentInputSize);
-  currentOutputSize += splitTransLSTM->getOutputSize();
-  currentInputSize += splitTransLSTM->getInputSize();
-
-  for (unsigned int i = 0; i < focusedColumns.size(); i++)
-  {
-    focusedLstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), FocusedColumnLSTM(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedLSTMSize, lstmOptions)));
-    focusedLstms.back()->setFirstInputIndex(currentInputSize);
-    currentOutputSize += focusedLstms.back()->getOutputSize();
-    currentInputSize += focusedLstms.back()->getInputSize();
-  }
-
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
-  if (drop2d)
-    embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue));
-  else
-    embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue));
-  inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout));
-
-  mlp = register_module("mlp", MLP(currentOutputSize, mlpParams));
-
-  for (auto & it : nbOutputsPerState)
outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second))); +// MyModule::ModuleOptions moduleOptions{true,bilstm,numLayers,lstmDropout,false}; +// auto moduleOptionsAll = moduleOptions; +// std::get<4>(moduleOptionsAll) = true; +// +// int currentOutputSize = embeddingsSize; +// int currentInputSize = 1; +// +// contextLSTM = register_module("contextLSTM", ContextModule(columns, embeddingsSize, contextLSTMSize, bufferContext, stackContext, moduleOptions, unknownValueThreshold)); +// contextLSTM->setFirstInputIndex(currentInputSize); +// currentOutputSize += contextLSTM->getOutputSize(); +// currentInputSize += contextLSTM->getInputSize(); +// +// if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0) +// { +// rawInputLSTM = register_module("rawInputLSTM", RawInputModule(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, moduleOptionsAll)); +// rawInputLSTM->setFirstInputIndex(currentInputSize); +// currentOutputSize += rawInputLSTM->getOutputSize(); +// currentInputSize += rawInputLSTM->getInputSize(); +// } +// +// if (!treeEmbeddingColumns.empty()) +// { +// treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbeddingModule(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,moduleOptions)); +// treeEmbedding->setFirstInputIndex(currentInputSize); +// currentOutputSize += treeEmbedding->getOutputSize(); +// currentInputSize += treeEmbedding->getInputSize(); +// } +// +// splitTransLSTM = register_module("splitTransLSTM", SplitTransModule(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransLSTMSize, moduleOptionsAll)); +// splitTransLSTM->setFirstInputIndex(currentInputSize); +// currentOutputSize += splitTransLSTM->getOutputSize(); +// currentInputSize += splitTransLSTM->getInputSize(); +// +// for (unsigned int i = 0; i < focusedColumns.size(); i++) +// { +// focusedLstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), FocusedColumnModule(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedLSTMSize, moduleOptions))); +// focusedLstms.back()->setFirstInputIndex(currentInputSize); +// currentOutputSize += focusedLstms.back()->getOutputSize(); +// currentInputSize += focusedLstms.back()->getInputSize(); +// } +// +// wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); +// if (drop2d) +// embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue)); +// else +// embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue)); +// inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout)); +// +// mlp = register_module("mlp", MLP(currentOutputSize, mlpParams)); +// +// for (auto & it : nbOutputsPerState) +// outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second))); } torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input) { - if (input.dim() == 1) - input = input.unsqueeze(0); - - auto embeddings = wordEmbeddings(input); - if (embeddingsDropout2d.is_empty()) - embeddings = embeddingsDropout(embeddings); - else - embeddings = embeddingsDropout2d(embeddings); - - std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)}; - - 
outputs.emplace_back(contextLSTM(embeddings)); - - if (!rawInputLSTM.is_empty()) - outputs.emplace_back(rawInputLSTM(embeddings)); - - if (!treeEmbedding.is_empty()) - outputs.emplace_back(treeEmbedding(embeddings)); - - outputs.emplace_back(splitTransLSTM(embeddings)); - - for (auto & lstm : focusedLstms) - outputs.emplace_back(lstm(embeddings)); - - auto totalInput = inputDropout(torch::cat(outputs, 1)); - - return outputLayersPerState.at(getState())(mlp(totalInput)); + return input; +// if (input.dim() == 1) +// input = input.unsqueeze(0); +// +// auto embeddings = wordEmbeddings(input); +// if (embeddingsDropout2d.is_empty()) +// embeddings = embeddingsDropout(embeddings); +// else +// embeddings = embeddingsDropout2d(embeddings); +// +// std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)}; +// +// outputs.emplace_back(contextLSTM(embeddings)); +// +// if (!rawInputLSTM.is_empty()) +// outputs.emplace_back(rawInputLSTM(embeddings)); +// +// if (!treeEmbedding.is_empty()) +// outputs.emplace_back(treeEmbedding(embeddings)); +// +// outputs.emplace_back(splitTransLSTM(embeddings)); +// +// for (auto & lstm : focusedLstms) +// outputs.emplace_back(lstm(embeddings)); +// +// auto totalInput = inputDropout(torch::cat(outputs, 1)); +// +// return outputLayersPerState.at(getState())(mlp(totalInput)); } std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const { - if (dict.size() >= maxNbEmbeddings) - util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings)); - std::vector<std::vector<long>> context; - context.emplace_back(); - - context.back().emplace_back(dict.getIndexOrInsert(config.getState())); - - contextLSTM->addToContext(context, dict, config, mustSplitUnknown()); - - if (!rawInputLSTM.is_empty()) - rawInputLSTM->addToContext(context, dict, config, mustSplitUnknown()); - - if (!treeEmbedding.is_empty()) - treeEmbedding->addToContext(context, dict, config, mustSplitUnknown()); - - splitTransLSTM->addToContext(context, dict, config, mustSplitUnknown()); - - for (auto & lstm : focusedLstms) - lstm->addToContext(context, dict, config, mustSplitUnknown()); - - if (!mustSplitUnknown() && context.size() > 1) - util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size())); - return context; +// if (dict.size() >= maxNbEmbeddings) +// util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings)); +// +// context.emplace_back(); +// +// context.back().emplace_back(dict.getIndexOrInsert(config.getState())); +// +// contextLSTM->addToContext(context, dict, config, mustSplitUnknown()); +// +// if (!rawInputLSTM.is_empty()) +// rawInputLSTM->addToContext(context, dict, config, mustSplitUnknown()); +// +// if (!treeEmbedding.is_empty()) +// treeEmbedding->addToContext(context, dict, config, mustSplitUnknown()); +// +// splitTransLSTM->addToContext(context, dict, config, mustSplitUnknown()); +// +// for (auto & lstm : focusedLstms) +// lstm->addToContext(context, dict, config, mustSplitUnknown()); +// +// if (!mustSplitUnknown() && context.size() > 1) +// util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size())); +// +// return context; } diff --git a/torch_modules/src/MLP.cpp b/torch_modules/src/MLP.cpp index 03880ecca0554d1ab4ad1b0c92fca4f1a0ea2481..ede2ff63f46ee50a640a8e041389974caa5fff64 100644 --- a/torch_modules/src/MLP.cpp +++ 
b/torch_modules/src/MLP.cpp @@ -1,8 +1,25 @@ #include "MLP.hpp" +#include "util.hpp" #include "fmt/core.h" +#include <regex> -MLPImpl::MLPImpl(int inputSize, std::vector<std::pair<int, float>> params) +MLPImpl::MLPImpl(int inputSize, std::string definition) { + std::regex regex("(?:(?:\\s|\\t)*)\\{(.*)\\}(?:(?:\\s|\\t)*)"); + std::vector<std::pair<int, float>> params; + if (!util::doIfNameMatch(regex, definition, [this,&definition,¶ms](auto sm) + { + try + { + auto splited = util::split(sm.str(1), ' '); + for (unsigned int i = 0; i < splited.size()/2; i++) + { + params.emplace_back(std::stoi(splited[2*i]), std::stof(splited[2*i+1])); + } + } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} + })) + util::myThrow(fmt::format("invalid definition '{}'", definition)); + int inSize = inputSize; for (auto & param : params) diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp new file mode 100644 index 0000000000000000000000000000000000000000..21edb84ffb337b5f0bf544e5bf71580153888335 --- /dev/null +++ b/torch_modules/src/ModularNetwork.cpp @@ -0,0 +1,77 @@ +#include "ModularNetwork.hpp" + +ModularNetworkImpl::ModularNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions) +{ + std::string anyBlanks = "(?:(?:\\s|\\t)*)"; + auto splitLine = [anyBlanks](std::string line) + { + std::pair<std::string,std::string> result; + util::doIfNameMatch(std::regex(fmt::format("{}(\\S+){}:{}(.+)",anyBlanks,anyBlanks,anyBlanks)),line,[&result](auto sm) + { + result.first = sm.str(1); + result.second = sm.str(2); + }); + return result; + }; + + int currentInputSize = 0; + int currentOutputSize = 0; + std::string mlpDef; + for (auto & line : definitions) + { + auto splited = splitLine(line); + std::string name = fmt::format("{}_{}", modules.size(), splited.first); + if (splited.first == "Context") + modules.emplace_back(register_module(name, ContextModule(splited.second))); + else if (splited.first == "MLP") + { + mlpDef = splited.second; + continue; + } + else if (splited.first == "InputDropout") + { + inputDropout = register_module("inputDropout", torch::nn::Dropout(std::stof(splited.second))); + continue; + } + else + util::myThrow(fmt::format("unknown module '{}' for line '{}'", splited.first, line)); + + modules.back()->setFirstInputIndex(currentInputSize); + currentInputSize += modules.back()->getInputSize(); + currentOutputSize += modules.back()->getOutputSize(); + } + + if (mlpDef.empty()) + util::myThrow("no MLP definition found"); + if (inputDropout.is_empty()) + util::myThrow("no InputDropout definition found"); + + mlp = register_module("mlp", MLP(currentOutputSize, mlpDef)); + + for (auto & it : nbOutputsPerState) + outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second))); +} + +torch::Tensor ModularNetworkImpl::forward(torch::Tensor input) +{ + if (input.dim() == 1) + input = input.unsqueeze(0); + + std::vector<torch::Tensor> outputs; + + for (auto & mod : modules) + outputs.emplace_back(mod->forward(input)); + + auto totalInput = inputDropout(torch::cat(outputs, 1)); + + return outputLayersPerState.at(getState())(mlp(totalInput)); +} + +std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config, Dict & dict) const +{ + std::vector<std::vector<long>> context(1); + for (auto & mod : modules) + mod->addToContext(context, dict, config, mustSplitUnknown()); + return context; +} + 
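A note on the new 'Modular' network above: each definition line is a "Name : arguments" pair, and the constructor dispatches on the name ('Context', 'MLP', 'InputDropout') while each module parses its own argument string. The following standalone sketch mirrors the splitLine/dispatch step using plain std::regex in place of util::doIfNameMatch; the sample definition lines are invented for illustration, since the argument syntax accepted by ContextModule is not shown in this patch.

#include <iostream>
#include <regex>
#include <string>
#include <utility>
#include <vector>

// "Name : arguments" -> (Name, arguments); std::regex_match stands in
// for the project's util::doIfNameMatch helper.
std::pair<std::string, std::string> splitLine(const std::string & line)
{
  std::smatch sm;
  if (std::regex_match(line, sm, std::regex("\\s*(\\S+)\\s*:\\s*(.+)")))
    return {sm.str(1), sm.str(2)};
  return {};
}

int main()
{
  // Hypothetical 'Modular' classifier block; only 'Context', 'MLP' and
  // 'InputDropout' are recognized by the constructor above.
  std::vector<std::string> definitions{
    "Context : {FORM UPOS} {-3 -2 -1 0 1 2}",
    "MLP : {1600 0.3 800 0.3}",
    "InputDropout : 0.5",
  };

  for (auto & line : definitions)
  {
    auto [name, args] = splitLine(line);
    std::cout << name << " -> " << args << "\n";
  }
  return 0;
}

The constructor additionally threads currentInputSize/currentOutputSize through the recognized modules, so each module knows which slice of the input tensor belongs to it and how much it contributes to the MLP's input width.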
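Likewise, MLPImpl now receives its layer specification as a brace-delimited string of alternating hidden-layer sizes and dropout rates. A minimal sketch of that parse, again substituting std::regex and a string stream for the util helpers; the value "{1600 0.3 800 0.3}" is an invented example:

#include <iostream>
#include <regex>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

int main()
{
  // Hypothetical MLP definition: pairs of (hidden layer size, dropout rate).
  std::string definition = "{1600 0.3 800 0.3}";

  std::vector<std::pair<int, float>> params;
  std::smatch sm;
  if (std::regex_match(definition, sm, std::regex("\\s*\\{(.*)\\}\\s*")))
  {
    std::istringstream iss(sm.str(1));
    int size;
    float dropout;
    while (iss >> size >> dropout) // consume the list two tokens at a time
      params.emplace_back(size, dropout);
  }

  // MLPImpl would build one Linear+Dropout pair per entry.
  for (auto & p : params)
    std::cout << "Linear(out=" << p.first << ") Dropout(" << p.second << ")\n";
  return 0;
}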
diff --git a/torch_modules/src/RawInputLSTM.cpp b/torch_modules/src/RawInputModule.cpp
similarity index 53%
rename from torch_modules/src/RawInputLSTM.cpp
rename to torch_modules/src/RawInputModule.cpp
index c6da426a7807b90bfd52eaf06abe7599c4c517c3..e14451dd206b2dea8a3d50f91a603fca97b74e26 100644
--- a/torch_modules/src/RawInputLSTM.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -1,26 +1,26 @@
-#include "RawInputLSTM.hpp"
+#include "RawInputModule.hpp"
 
-RawInputLSTMImpl::RawInputLSTMImpl(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : leftWindow(leftWindow), rightWindow(rightWindow)
+RawInputModule::RawInputModule(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options) : leftWindow(leftWindow), rightWindow(rightWindow)
 {
-  lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
+  myModule = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
 }
 
-torch::Tensor RawInputLSTMImpl::forward(torch::Tensor input)
+torch::Tensor RawInputModule::forward(torch::Tensor input)
 {
-  return lstm(input.narrow(1, firstInputIndex, getInputSize()));
+  return myModule->forward(input.narrow(1, firstInputIndex, getInputSize()));
 }
 
-std::size_t RawInputLSTMImpl::getOutputSize()
+std::size_t RawInputModule::getOutputSize()
 {
-  return lstm->getOutputSize(leftWindow + rightWindow + 1);
+  return myModule->getOutputSize(leftWindow + rightWindow + 1);
 }
 
-std::size_t RawInputLSTMImpl::getInputSize()
+std::size_t RawInputModule::getInputSize()
 {
   return leftWindow + rightWindow + 1;
 }
 
-void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void RawInputModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   if (leftWindow < 0 or rightWindow < 0)
     return;
diff --git a/torch_modules/src/SplitTransLSTM.cpp b/torch_modules/src/SplitTransLSTM.cpp
deleted file mode 100644
index 99a1b35650e0b60c8c34c22f0a863d1ab1f8c990..0000000000000000000000000000000000000000
--- a/torch_modules/src/SplitTransLSTM.cpp
+++ /dev/null
@@ -1,34 +0,0 @@
-#include "SplitTransLSTM.hpp"
-#include "Transition.hpp"
-
-SplitTransLSTMImpl::SplitTransLSTMImpl(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : maxNbTrans(maxNbTrans)
-{
-  lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
-}
-
-torch::Tensor SplitTransLSTMImpl::forward(torch::Tensor input)
-{
-  return lstm(input.narrow(1, firstInputIndex, getInputSize()));
-}
-
-std::size_t SplitTransLSTMImpl::getOutputSize()
-{
-  return lstm->getOutputSize(maxNbTrans);
-}
-
-std::size_t SplitTransLSTMImpl::getInputSize()
-{
-  return maxNbTrans;
-}
-
-void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
-{
-  auto & splitTransitions = config.getAppliableSplitTransitions();
-  for (auto & contextElement : context)
-    for (int i = 0; i < maxNbTrans; i++)
-      if (i < (int)splitTransitions.size())
-        contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
-      else
-        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
-}
-
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..45a48df4488e884af925d6b50bfca69bdba73e60
--- /dev/null
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -0,0 +1,34 @@
+#include "SplitTransModule.hpp"
+#include "Transition.hpp"
+
+SplitTransModule::SplitTransModule(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options) : maxNbTrans(maxNbTrans)
+{
+  myModule = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
+}
+
+torch::Tensor SplitTransModule::forward(torch::Tensor input)
+{
+  return myModule->forward(input.narrow(1, firstInputIndex, getInputSize()));
+}
+
+std::size_t SplitTransModule::getOutputSize()
+{
+  return myModule->getOutputSize(maxNbTrans);
+}
+
+std::size_t SplitTransModule::getInputSize()
+{
+  return maxNbTrans;
+}
+
+void SplitTransModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+{
+  auto & splitTransitions = config.getAppliableSplitTransitions();
+  for (auto & contextElement : context)
+    for (int i = 0; i < maxNbTrans; i++)
+      if (i < (int)splitTransitions.size())
+        contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
+      else
+        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+}
+
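For reference, SplitTransModule::addToContext gives every context row a fixed width: one dictionary index per currently appliable split transition, padded with Dict::nullValueStr up to maxNbTrans, so the module always slices the same number of input columns. A dictionary-free sketch of that padding rule, with invented transition names:

#include <iostream>
#include <string>
#include <vector>

int main()
{
  const int maxNbTrans = 4;                     // fixed input width of the module
  const std::string nullValueStr = "__null__";  // stand-in for Dict::nullValueStr

  // Invented names; in the parser these come from
  // config.getAppliableSplitTransitions().
  std::vector<std::string> splitTransitions{"SPLIT_2", "SPLIT_3"};

  std::vector<std::string> contextElement;
  for (int i = 0; i < maxNbTrans; i++)
    if (i < (int)splitTransitions.size())
      contextElement.emplace_back(splitTransitions[i]);
    else
      contextElement.emplace_back(nullValueStr);

  for (auto & s : contextElement)
    std::cout << s << ' ';
  std::cout << '\n'; // prints: SPLIT_2 SPLIT_3 __null__ __null__
  return 0;
}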