From ca72f53e82460ff67eded3de6d6ae7fa969ae80f Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 29 Apr 2020 15:18:43 +0200
Subject: [PATCH] Implement all modules

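Remove the monolithic LSTMNetwork and GRUNetwork classes, together
with their duplicated line-by-line option parsers in Classifier, in
favor of the ModularNetwork architecture.

Each submodule (DepthLayerTreeEmbeddingModule, FocusedColumnModule,
RawInputModule, SplitTransModule) is renamed to *Impl and wrapped with
TORCH_MODULE, parses its own one-line definition string, owns its own
word embeddings, and instantiates either an LSTM or a GRU recurrent
block according to that definition.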
---
 reading_machine/include/Classifier.hpp        |   2 -
 reading_machine/src/Classifier.cpp            | 444 ------------------
 .../include/DepthLayerTreeEmbeddingModule.hpp |   6 +-
 torch_modules/include/FocusedColumnModule.hpp |   6 +-
 torch_modules/include/GRUNetwork.hpp          |  36 --
 torch_modules/include/LSTMNetwork.hpp         |  36 --
 torch_modules/include/ModularNetwork.hpp      |   6 +-
 torch_modules/include/RawInputModule.hpp      |   6 +-
 torch_modules/include/SplitTransModule.hpp    |   6 +-
 .../src/DepthLayerTreeEmbeddingModule.cpp     |  61 ++-
 torch_modules/src/FocusedColumnModule.cpp     |  51 +-
 torch_modules/src/GRUNetwork.cpp              | 120 -----
 torch_modules/src/LSTMNetwork.cpp             | 120 -----
 torch_modules/src/ModularNetwork.cpp          |   8 +
 torch_modules/src/RawInputModule.cpp          |  45 +-
 torch_modules/src/SplitTransModule.cpp        |  42 +-
 16 files changed, 195 insertions(+), 800 deletions(-)
 delete mode 100644 torch_modules/include/GRUNetwork.hpp
 delete mode 100644 torch_modules/include/LSTMNetwork.hpp
 delete mode 100644 torch_modules/src/GRUNetwork.cpp
 delete mode 100644 torch_modules/src/LSTMNetwork.cpp
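
Note on the new definition format (the values below are illustrative,
not taken from a real configuration): a FocusedColumn definition is a
single line of the form

  Column{FORM} NbElem{10} Buffer{0 1 2} Stack{0 1} LSTM{1 2 0.3 1} In{64} Out{128}

where the arguments of the recurrent block (LSTM or GRU) are
bidirectional, num_layers, dropout and complete, and In/Out give the
embedding and hidden sizes. The DepthLayerTreeEmbedding variant takes
Columns{...} and LayerSizes{...} in place of Column{...} and
NbElem{...}.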

diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 0684d56..29ff704 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -18,8 +18,6 @@ class Classifier
   private :
 
   void initNeuralNetwork(const std::vector<std::string> & definition);
-  void initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
-  void initGRU(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
   void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
 
   public :
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 113ee9b..64e4193 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -1,7 +1,5 @@
 #include "Classifier.hpp"
 #include "util.hpp"
-#include "LSTMNetwork.hpp"
-#include "GRUNetwork.hpp"
 #include "RandomNetwork.hpp"
 #include "ModularNetwork.hpp"
 
@@ -87,10 +85,6 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
 
   if (networkType == "Random")
     this->nn.reset(new RandomNetworkImpl(nbOutputsPerState));
-  else if (networkType == "LSTM")
-    initLSTM(definition, curIndex, nbOutputsPerState);
-  else if (networkType == "GRU")
-    initGRU(definition, curIndex, nbOutputsPerState);
   else if (networkType == "Modular")
     initModular(definition, curIndex, nbOutputsPerState);
   else
@@ -115,225 +109,6 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}"));
 }
 
-void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
-{
-  int unknownValueThreshold;
-  std::vector<int> bufferContext, stackContext;
-  std::vector<std::string> columns, focusedColumns, treeEmbeddingColumns;
-  std::vector<int> focusedBuffer, focusedStack;
-  std::vector<int> treeEmbeddingBuffer, treeEmbeddingStack;
-  std::vector<int> maxNbElements;
-  std::vector<int> treeEmbeddingNbElems;
-  std::vector<std::pair<int, float>> mlp;
-  int rawInputLeftWindow, rawInputRightWindow;
-  int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, treeEmbeddingSize;
-  bool bilstm, drop2d;
-  float lstmDropout, embeddingsDropout, totalInputDropout;
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Unknown value threshold :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&unknownValueThreshold](auto sm)
-        {
-          unknownValueThreshold = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Unknown value threshold :) unknownValueThreshold"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Buffer context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&bufferContext](auto sm)
-        {
-          for (auto & index : util::split(sm.str(1), ' '))
-            bufferContext.emplace_back(std::stoi(index));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Buffer context :) {index1 index2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Stack context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&stackContext](auto sm)
-        {
-          for (auto & index : util::split(sm.str(1), ' '))
-            stackContext.emplace_back(std::stoi(index));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Stack context :) {index1 index2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&columns](auto sm)
-        {
-          columns = util::split(sm.str(1), ' ');
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Columns :) {index1 index2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedBuffer](auto sm)
-        {
-          for (auto & index : util::split(sm.str(1), ' '))
-            focusedBuffer.emplace_back(std::stoi(index));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused buffer :) {index1 index2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedStack](auto sm)
-        {
-          for (auto & index : util::split(sm.str(1), ' '))
-            focusedStack.emplace_back(std::stoi(index));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused stack :) {index1 index2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedColumns](auto sm)
-        {
-          focusedColumns = util::split(sm.str(1), ' ');
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused columns :) {index1 index2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Max nb elements :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&maxNbElements](auto sm)
-        {
-          for (auto & index : util::split(sm.str(1), ' '))
-            maxNbElements.emplace_back(std::stoi(index));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Max nb elements :) {size1 size2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input left window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLeftWindow](auto sm)
-        {
-          rawInputLeftWindow = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input left window :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input right window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputRightWindow](auto sm)
-        {
-          rawInputRightWindow = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input right window :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsSize](auto sm)
-        {
-          embeddingsSize = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings size :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:MLP :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&mlp](auto sm)
-        {
-          auto params = util::split(sm.str(1), ' ');
-          if (params.size() % 2)
-            util::myThrow("MLP must have even number of parameters");
-          for (unsigned int i = 0; i < params.size()/2; i++)
-            mlp.emplace_back(std::make_pair(std::stoi(params[2*i]), std::stof(params[2*i+1])));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(MLP :) {hidden1 dropout1 hidden2 dropout2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Context LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&contextLSTMSize](auto sm)
-        {
-          contextLSTMSize = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Context LSTM size :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&focusedLSTMSize](auto sm)
-        {
-          focusedLSTMSize = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused LSTM size :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Rawinput LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLSTMSize](auto sm)
-        {
-          rawInputLSTMSize = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw LSTM size :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Split trans LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&splitTransLSTMSize](auto sm)
-        {
-          splitTransLSTMSize = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Split trans LSTM size :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Num layers :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&nbLayers](auto sm)
-        {
-          nbLayers = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Num layers :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:BiLSTM :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&bilstm](auto sm)
-        {
-          bilstm = sm.str(1) == "true";
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(BiLSTM :) true|false"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LSTM dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&lstmDropout](auto sm)
-        {
-          lstmDropout = std::stof(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(LSTM dropout :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Total input dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&totalInputDropout](auto sm)
-        {
-          totalInputDropout = std::stof(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Total input dropout :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsDropout](auto sm)
-        {
-          embeddingsDropout = std::stof(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings dropout :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Dropout 2d :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&drop2d](auto sm)
-        {
-          drop2d = sm.str(1) == "true";
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Dropout 2d :) true|false"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingColumns](auto sm)
-        {
-          treeEmbeddingColumns = util::split(sm.str(1), ' ');
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding columns :) {column1 column2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingBuffer](auto sm)
-        {
-          for (auto & index : util::split(sm.str(1), ' '))
-            treeEmbeddingBuffer.emplace_back(std::stoi(index));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding buffer :) {index1 index2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingStack](auto sm)
-        {
-          for (auto & index : util::split(sm.str(1), ' '))
-            treeEmbeddingStack.emplace_back(std::stoi(index));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding stack :) {index1 index2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding nb :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingNbElems](auto sm)
-        {
-          for (auto & index : util::split(sm.str(1), ' '))
-            treeEmbeddingNbElems.emplace_back(std::stoi(index));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding nb :) {size1 size2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&treeEmbeddingSize](auto sm)
-        {
-          treeEmbeddingSize = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding size :) value"));
-
-  this->nn.reset(new LSTMNetworkImpl(nbOutputsPerState, unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d));
-}
-
 void Classifier::loadOptimizer(std::filesystem::path path)
 {
   torch::load(*optimizer, path);
@@ -355,225 +130,6 @@ void Classifier::setState(const std::string & state)
   nn->setState(state);
 }
 
-void Classifier::initGRU(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
-{
-  int unknownValueThreshold;
-  std::vector<int> bufferContext, stackContext;
-  std::vector<std::string> columns, focusedColumns, treeEmbeddingColumns;
-  std::vector<int> focusedBuffer, focusedStack;
-  std::vector<int> treeEmbeddingBuffer, treeEmbeddingStack;
-  std::vector<int> maxNbElements;
-  std::vector<int> treeEmbeddingNbElems;
-  std::vector<std::pair<int, float>> mlp;
-  int rawInputLeftWindow, rawInputRightWindow;
-  int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, treeEmbeddingSize;
-  bool bilstm, drop2d;
-  float lstmDropout, embeddingsDropout, totalInputDropout;
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Unknown value threshold :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&unknownValueThreshold](auto sm)
-        {
-          unknownValueThreshold = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Unknown value threshold :) unknownValueThreshold"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Buffer context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&bufferContext](auto sm)
-        {
-          for (auto & index : util::split(sm.str(1), ' '))
-            bufferContext.emplace_back(std::stoi(index));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Buffer context :) {index1 index2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Stack context :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&stackContext](auto sm)
-        {
-          for (auto & index : util::split(sm.str(1), ' '))
-            stackContext.emplace_back(std::stoi(index));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Stack context :) {index1 index2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&columns](auto sm)
-        {
-          columns = util::split(sm.str(1), ' ');
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Columns :) {index1 index2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedBuffer](auto sm)
-        {
-          for (auto & index : util::split(sm.str(1), ' '))
-            focusedBuffer.emplace_back(std::stoi(index));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused buffer :) {index1 index2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedStack](auto sm)
-        {
-          for (auto & index : util::split(sm.str(1), ' '))
-            focusedStack.emplace_back(std::stoi(index));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused stack :) {index1 index2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&focusedColumns](auto sm)
-        {
-          focusedColumns = util::split(sm.str(1), ' ');
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused columns :) {index1 index2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Max nb elements :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&maxNbElements](auto sm)
-        {
-          for (auto & index : util::split(sm.str(1), ' '))
-            maxNbElements.emplace_back(std::stoi(index));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Max nb elements :) {size1 size2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input left window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLeftWindow](auto sm)
-        {
-          rawInputLeftWindow = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input left window :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Raw input right window :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputRightWindow](auto sm)
-        {
-          rawInputRightWindow = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw input right window :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsSize](auto sm)
-        {
-          embeddingsSize = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings size :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:MLP :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&mlp](auto sm)
-        {
-          auto params = util::split(sm.str(1), ' ');
-          if (params.size() % 2)
-            util::myThrow("MLP must have even number of parameters");
-          for (unsigned int i = 0; i < params.size()/2; i++)
-            mlp.emplace_back(std::make_pair(std::stoi(params[2*i]), std::stof(params[2*i+1])));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(MLP :) {hidden1 dropout1 hidden2 dropout2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Context LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&contextLSTMSize](auto sm)
-        {
-          contextLSTMSize = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Context LSTM size :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Focused LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&focusedLSTMSize](auto sm)
-        {
-          focusedLSTMSize = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Focused LSTM size :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Rawinput LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&rawInputLSTMSize](auto sm)
-        {
-          rawInputLSTMSize = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Raw LSTM size :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Split trans LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&splitTransLSTMSize](auto sm)
-        {
-          splitTransLSTMSize = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Split trans LSTM size :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Num layers :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&nbLayers](auto sm)
-        {
-          nbLayers = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Num layers :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:BiLSTM :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&bilstm](auto sm)
-        {
-          bilstm = sm.str(1) == "true";
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(BiLSTM :) true|false"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LSTM dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&lstmDropout](auto sm)
-        {
-          lstmDropout = std::stof(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(LSTM dropout :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Total input dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&totalInputDropout](auto sm)
-        {
-          totalInputDropout = std::stof(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Total input dropout :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsDropout](auto sm)
-        {
-          embeddingsDropout = std::stof(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings dropout :) value"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Dropout 2d :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&drop2d](auto sm)
-        {
-          drop2d = sm.str(1) == "true";
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Dropout 2d :) true|false"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingColumns](auto sm)
-        {
-          treeEmbeddingColumns = util::split(sm.str(1), ' ');
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding columns :) {column1 column2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding buffer :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingBuffer](auto sm)
-        {
-          for (auto & index : util::split(sm.str(1), ' '))
-            treeEmbeddingBuffer.emplace_back(std::stoi(index));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding buffer :) {index1 index2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding stack :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingStack](auto sm)
-        {
-          for (auto & index : util::split(sm.str(1), ' '))
-            treeEmbeddingStack.emplace_back(std::stoi(index));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding stack :) {index1 index2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding nb :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingNbElems](auto sm)
-        {
-          for (auto & index : util::split(sm.str(1), ' '))
-            treeEmbeddingNbElems.emplace_back(std::stoi(index));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding nb :) {size1 size2...}"));
-
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&treeEmbeddingSize](auto sm)
-        {
-          treeEmbeddingSize = std::stoi(sm.str(1));
-          curIndex++;
-        }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding size :) value"));
-
-  this->nn.reset(new GRUNetworkImpl(nbOutputsPerState, unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d));
-}
-
 void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
 {
   std::string anyBlanks = "(?:(?:\\s|\\t)*)";
diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
index cd6c33d..0d5cedd 100644
--- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
+++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
@@ -7,7 +7,7 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 
-class DepthLayerTreeEmbeddingModule : public Submodule
+class DepthLayerTreeEmbeddingModuleImpl : public Submodule
 {
   private :
 
@@ -15,16 +15,18 @@ class DepthLayerTreeEmbeddingModule : public Submodule
   std::vector<std::string> columns;
   std::vector<int> focusedBuffer;
   std::vector<int> focusedStack;
+  torch::nn::Embedding wordEmbeddings{nullptr};
   std::vector<std::shared_ptr<MyModule>> depthModules;
 
   public :
 
-  DepthLayerTreeEmbeddingModule(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, MyModule::ModuleOptions options);
+  DepthLayerTreeEmbeddingModuleImpl(const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
+TORCH_MODULE(DepthLayerTreeEmbeddingModule);
 
 #endif
 
diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp
index 9c2732a..c105193 100644
--- a/torch_modules/include/FocusedColumnModule.hpp
+++ b/torch_modules/include/FocusedColumnModule.hpp
@@ -7,10 +7,11 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 
-class FocusedColumnModule : public Submodule
+class FocusedColumnModuleImpl : public Submodule
 {
   private :
 
+  torch::nn::Embedding wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<int> focusedBuffer, focusedStack;
   std::string column;
@@ -18,12 +19,13 @@ class FocusedColumnModule : public Submodule
 
   public :
 
-  FocusedColumnModule(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options);
+  FocusedColumnModuleImpl(const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
+TORCH_MODULE(FocusedColumnModule);
 
 #endif
 
diff --git a/torch_modules/include/GRUNetwork.hpp b/torch_modules/include/GRUNetwork.hpp
deleted file mode 100644
index ecff8a0..0000000
--- a/torch_modules/include/GRUNetwork.hpp
+++ /dev/null
@@ -1,36 +0,0 @@
-#ifndef GRUNETWORK__H
-#define GRUNETWORK__H
-
-#include "NeuralNetwork.hpp"
-#include "ContextModule.hpp"
-#include "RawInputModule.hpp"
-#include "SplitTransModule.hpp"
-#include "FocusedColumnModule.hpp"
-#include "DepthLayerTreeEmbeddingModule.hpp"
-#include "MLP.hpp"
-
-class GRUNetworkImpl : public NeuralNetworkImpl
-{
-//  private :
-//
-//  torch::nn::Embedding wordEmbeddings{nullptr};
-//  torch::nn::Dropout2d embeddingsDropout2d{nullptr};
-//  torch::nn::Dropout embeddingsDropout{nullptr};
-//  torch::nn::Dropout inputDropout{nullptr};
-//
-//  MLP mlp{nullptr};
-//  ContextModule contextGRU{nullptr};
-//  RawInputModule rawInputGRU{nullptr};
-//  SplitTransModule splitTransGRU{nullptr};
-//  DepthLayerTreeEmbeddingModule treeEmbedding{nullptr};
-//  std::vector<FocusedColumnModule> focusedLstms;
-//  std::map<std::string,torch::nn::Linear> outputLayersPerState;
-
-  public :
-
-  GRUNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextGRUSize, int focusedGRUSize, int rawInputGRUSize, int splitTransGRUSize, int numLayers, bool bigru, float gruDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d);
-  torch::Tensor forward(torch::Tensor input) override;
-  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
-};
-
-#endif
diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp
deleted file mode 100644
index d83acf5..0000000
--- a/torch_modules/include/LSTMNetwork.hpp
+++ /dev/null
@@ -1,36 +0,0 @@
-#ifndef LSTMNETWORK__H
-#define LSTMNETWORK__H
-
-#include "NeuralNetwork.hpp"
-#include "ContextModule.hpp"
-#include "RawInputModule.hpp"
-#include "SplitTransModule.hpp"
-#include "FocusedColumnModule.hpp"
-#include "DepthLayerTreeEmbeddingModule.hpp"
-#include "MLP.hpp"
-
-class LSTMNetworkImpl : public NeuralNetworkImpl
-{
-//  private :
-//
-//  torch::nn::Embedding wordEmbeddings{nullptr};
-//  torch::nn::Dropout2d embeddingsDropout2d{nullptr};
-//  torch::nn::Dropout embeddingsDropout{nullptr};
-//  torch::nn::Dropout inputDropout{nullptr};
-//
-//  MLP mlp{nullptr};
-//  ContextModule contextLSTM{nullptr};
-//  RawInputModule rawInputLSTM{nullptr};
-//  SplitTransModule splitTransLSTM{nullptr};
-//  DepthLayerTreeEmbeddingModule treeEmbedding{nullptr};
-//  std::vector<FocusedColumnModule> focusedLstms;
-//  std::map<std::string,torch::nn::Linear> outputLayersPerState;
-
-  public :
-
-  LSTMNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d);
-  torch::Tensor forward(torch::Tensor input) override;
-  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
-};
-
-#endif
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index 871c8bb..931aca8 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -13,9 +13,9 @@ class ModularNetworkImpl : public NeuralNetworkImpl
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
-  torch::nn::Dropout2d embeddingsDropout2d{nullptr};
-  torch::nn::Dropout embeddingsDropout{nullptr};
+  //torch::nn::Embedding wordEmbeddings{nullptr};
+  //torch::nn::Dropout2d embeddingsDropout2d{nullptr};
+  //torch::nn::Dropout embeddingsDropout{nullptr};
   torch::nn::Dropout inputDropout{nullptr};
 
   MLP mlp{nullptr};
diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp
index 0134f97..4ded915 100644
--- a/torch_modules/include/RawInputModule.hpp
+++ b/torch_modules/include/RawInputModule.hpp
@@ -7,21 +7,23 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 
-class RawInputModule : public Submodule
+class RawInputModuleImpl : public Submodule
 {
   private :
 
+  torch::nn::Embedding wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int leftWindow, rightWindow;
 
   public :
 
-  RawInputModule(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options);
+  RawInputModuleImpl(const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
+TORCH_MODULE(RawInputModule);
 
 #endif
 
diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp
index 8be8e63..24c6841 100644
--- a/torch_modules/include/SplitTransModule.hpp
+++ b/torch_modules/include/SplitTransModule.hpp
@@ -7,21 +7,23 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 
-class SplitTransModule : public Submodule
+class SplitTransModuleImpl : public Submodule
 {
   private :
 
+  torch::nn::Embedding wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int maxNbTrans;
 
   public :
 
-  SplitTransModule(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options);
+  SplitTransModuleImpl(int maxNbTrans, const std::string & definition);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
   void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
+TORCH_MODULE(SplitTransModule);
 
 #endif
 
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index f3d7bb2..7e13cdc 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -1,15 +1,56 @@
 #include "DepthLayerTreeEmbeddingModule.hpp"
 
-DepthLayerTreeEmbeddingModule::DepthLayerTreeEmbeddingModule(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, MyModule::ModuleOptions options) :
-  maxElemPerDepth(maxElemPerDepth), columns(columns), focusedBuffer(focusedBuffer), focusedStack(focusedStack)
+DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(const std::string & definition)
 {
-  for (unsigned int i = 0; i < maxElemPerDepth.size(); i++)
-    depthModules.emplace_back(register_module(fmt::format("lstm_{}",i), LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options)));
+  std::regex regex("(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)LayerSizes\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+        {
+          try
+          {
+            columns = util::split(sm.str(1), ' ');
+
+            for (auto & index : util::split(sm.str(2), ' '))
+              focusedBuffer.emplace_back(std::stoi(index));
+
+            for (auto & index : util::split(sm.str(3), ' '))
+              focusedStack.emplace_back(std::stoi(index));
+
+            for (auto & elem : util::split(sm.str(4), ' '))
+              maxElemPerDepth.emplace_back(std::stoi(elem));
+
+            auto subModuleType = sm.str(5);
+            auto subModuleArguments = util::split(sm.str(6), ' ');
+
+            auto options = MyModule::ModuleOptions(true)
+              .bidirectional(std::stoi(subModuleArguments[0]))
+              .num_layers(std::stoi(subModuleArguments[1]))
+              .dropout(std::stof(subModuleArguments[2]))
+              .complete(std::stoi(subModuleArguments[3]));
+
+            int inSize = std::stoi(sm.str(7));
+            int outSize = std::stoi(sm.str(8));
+
+            wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
+
+            for (unsigned int i = 0; i < maxElemPerDepth.size(); i++)
+            {
+              std::string name = fmt::format("{}_{}", i, subModuleType);
+              if (subModuleType == "LSTM")
+                depthModules.emplace_back(register_module(name, LSTM(columns.size()*inSize, outSize, options)));
+              else if (subModuleType == "GRU")
+                depthModules.emplace_back(register_module(name, GRU(columns.size()*inSize, outSize, options)));
+              else
+                util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
+            }
+
+          } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+        }))
+    util::myThrow(fmt::format("invalid definition '{}'", definition));
 }
 
-torch::Tensor DepthLayerTreeEmbeddingModule::forward(torch::Tensor input)
+torch::Tensor DepthLayerTreeEmbeddingModuleImpl::forward(torch::Tensor input)
 {
-  auto context = input.narrow(1, firstInputIndex, getInputSize());
+  auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()));
 
   std::vector<torch::Tensor> outputs;
 
@@ -17,14 +58,14 @@ torch::Tensor DepthLayerTreeEmbeddingModule::forward(torch::Tensor input)
   for (unsigned int focused = 0; focused < focusedBuffer.size()+focusedStack.size(); focused++)
     for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++)
     {
-      outputs.emplace_back(depthModules[depth]->forward(context.narrow(1, offset, maxElemPerDepth[depth]*columns.size()).view({input.size(0), maxElemPerDepth[depth], (long)columns.size()*input.size(2)})));
+      outputs.emplace_back(depthModules[depth]->forward(context.narrow(1, offset, maxElemPerDepth[depth]*columns.size()).view({context.size(0), maxElemPerDepth[depth], (long)columns.size()*context.size(2)})));
       offset += maxElemPerDepth[depth]*columns.size();
     }
 
   return torch::cat(outputs, 1);
 }
 
-std::size_t DepthLayerTreeEmbeddingModule::getOutputSize()
+std::size_t DepthLayerTreeEmbeddingModuleImpl::getOutputSize()
 {
   std::size_t outputSize = 0;
 
@@ -34,7 +75,7 @@ std::size_t DepthLayerTreeEmbeddingModule::getOutputSize()
   return outputSize*(focusedBuffer.size()+focusedStack.size());
 }
 
-std::size_t DepthLayerTreeEmbeddingModule::getInputSize()
+std::size_t DepthLayerTreeEmbeddingModuleImpl::getInputSize()
 {
   int inputSize = 0;
   for (int maxElem : maxElemPerDepth)
@@ -42,7 +83,7 @@ std::size_t DepthLayerTreeEmbeddingModule::getInputSize()
   return inputSize;
 }
 
-void DepthLayerTreeEmbeddingModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   std::vector<long> focusedIndexes;
 
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 717bc74..08cb9eb 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -1,30 +1,67 @@
 #include "FocusedColumnModule.hpp"
 
-FocusedColumnModule::FocusedColumnModule(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options) : focusedBuffer(focusedBuffer), focusedStack(focusedStack), column(column), maxNbElements(maxNbElements)
+FocusedColumnModuleImpl::FocusedColumnModuleImpl(const std::string & definition)
 {
-  myModule = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
+  std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+        {
+          try
+          {
+            column = sm.str(1);
+            maxNbElements = std::stoi(sm.str(2));
+
+            for (auto & index : util::split(sm.str(3), ' '))
+              focusedBuffer.emplace_back(std::stoi(index));
+
+            for (auto & index : util::split(sm.str(4), ' '))
+              focusedStack.emplace_back(std::stoi(index));
+
+            auto subModuleType = sm.str(5);
+            auto subModuleArguments = util::split(sm.str(6), ' ');
+
+            auto options = MyModule::ModuleOptions(true)
+              .bidirectional(std::stoi(subModuleArguments[0]))
+              .num_layers(std::stoi(subModuleArguments[1]))
+              .dropout(std::stof(subModuleArguments[2]))
+              .complete(std::stoi(subModuleArguments[3]));
+
+            int inSize = std::stoi(sm.str(7));
+            int outSize = std::stoi(sm.str(8));
+
+            wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
+
+            if (subModuleType == "LSTM")
+              myModule = register_module("myModule", LSTM(inSize, outSize, options));
+            else if (subModuleType == "GRU")
+              myModule = register_module("myModule", GRU(inSize, outSize, options));
+            else
+              util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
+
+          } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+        }))
+    util::myThrow(fmt::format("invalid definition '{}'", definition));
 }
 
-torch::Tensor FocusedColumnModule::forward(torch::Tensor input)
+torch::Tensor FocusedColumnModuleImpl::forward(torch::Tensor input)
 {
   std::vector<torch::Tensor> outputs;
   for (unsigned int i = 0; i < focusedBuffer.size()+focusedStack.size(); i++)
-    outputs.emplace_back(myModule->forward(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements)));
+    outputs.emplace_back(myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements))));
 
   return torch::cat(outputs, 1);
 }
 
-std::size_t FocusedColumnModule::getOutputSize()
+std::size_t FocusedColumnModuleImpl::getOutputSize()
 {
   return (focusedBuffer.size()+focusedStack.size())*myModule->getOutputSize(maxNbElements);
 }
 
-std::size_t FocusedColumnModule::getInputSize()
+std::size_t FocusedColumnModuleImpl::getInputSize()
 {
   return (focusedBuffer.size()+focusedStack.size()) * maxNbElements;
 }
 
-void FocusedColumnModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   std::vector<long> focusedIndexes;
 
diff --git a/torch_modules/src/GRUNetwork.cpp b/torch_modules/src/GRUNetwork.cpp
deleted file mode 100644
index dc7f366..0000000
--- a/torch_modules/src/GRUNetwork.cpp
+++ /dev/null
@@ -1,120 +0,0 @@
-#include "GRUNetwork.hpp"
-
-GRUNetworkImpl::GRUNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextGRUSize, int focusedGRUSize, int rawInputGRUSize, int splitTransGRUSize, int numLayers, bool bigru, float gruDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d)
-{
-//  MyModule::ModuleOptions gruOptions{true,bigru,numLayers,gruDropout,false};
-//  auto gruOptionsAll = gruOptions;
-//  std::get<4>(gruOptionsAll) = true;
-//
-//  int currentOutputSize = embeddingsSize;
-//  int currentInputSize = 1;
-//
-//  contextGRU = register_module("contextGRU", ContextModule(columns, embeddingsSize, contextGRUSize, bufferContext, stackContext, gruOptions, unknownValueThreshold));
-//  contextGRU->setFirstInputIndex(currentInputSize);
-//  currentOutputSize += contextGRU->getOutputSize();
-//  currentInputSize += contextGRU->getInputSize();
-//
-//  if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
-//  {
-//    rawInputGRU = register_module("rawInputGRU", RawInputModule(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputGRUSize, gruOptionsAll));
-//    rawInputGRU->setFirstInputIndex(currentInputSize);
-//    currentOutputSize += rawInputGRU->getOutputSize();
-//    currentInputSize += rawInputGRU->getInputSize();
-//  }
-//
-//  if (!treeEmbeddingColumns.empty())
-//  {
-//    treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbeddingModule(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,gruOptions));
-//    treeEmbedding->setFirstInputIndex(currentInputSize);
-//    currentOutputSize += treeEmbedding->getOutputSize();
-//    currentInputSize += treeEmbedding->getInputSize();
-//  }
-//
-//  splitTransGRU = register_module("splitTransGRU", SplitTransModule(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransGRUSize, gruOptionsAll));
-//  splitTransGRU->setFirstInputIndex(currentInputSize);
-//  currentOutputSize += splitTransGRU->getOutputSize();
-//  currentInputSize += splitTransGRU->getInputSize();
-//
-//  for (unsigned int i = 0; i < focusedColumns.size(); i++)
-//  {
-//    focusedLstms.emplace_back(register_module(fmt::format("GRU_{}", focusedColumns[i]), FocusedColumnModule(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedGRUSize, gruOptions)));
-//    focusedLstms.back()->setFirstInputIndex(currentInputSize);
-//    currentOutputSize += focusedLstms.back()->getOutputSize();
-//    currentInputSize += focusedLstms.back()->getInputSize();
-//  }
-//
-//  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
-//  if (drop2d)
-//    embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue));
-//  else
-//    embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue));
-//  inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout));
-//
-//  mlp = register_module("mlp", MLP(currentOutputSize, mlpParams));
-//
-//  for (auto & it : nbOutputsPerState)
-//    outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
-}
-
-torch::Tensor GRUNetworkImpl::forward(torch::Tensor input)
-{
-  return input;
-//  if (input.dim() == 1)
-//    input = input.unsqueeze(0);
-//
-//  auto embeddings = wordEmbeddings(input);
-//  if (embeddingsDropout2d.is_empty())
-//    embeddings = embeddingsDropout(embeddings);
-//  else
-//    embeddings = embeddingsDropout2d(embeddings);
-//
-//  std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)};
-//
-//  outputs.emplace_back(contextGRU(embeddings));
-//
-//  if (!rawInputGRU.is_empty())
-//    outputs.emplace_back(rawInputGRU(embeddings));
-//
-//  if (!treeEmbedding.is_empty())
-//    outputs.emplace_back(treeEmbedding(embeddings));
-//
-//  outputs.emplace_back(splitTransGRU(embeddings));
-//
-//  for (auto & gru : focusedLstms)
-//    outputs.emplace_back(gru(embeddings));
-//
-//  auto totalInput = inputDropout(torch::cat(outputs, 1));
-//
-//  return outputLayersPerState.at(getState())(mlp(totalInput));
-}
-
-std::vector<std::vector<long>> GRUNetworkImpl::extractContext(Config & config, Dict & dict) const
-{
-  std::vector<std::vector<long>> context;
-  return context;
-//  if (dict.size() >= maxNbEmbeddings)
-//    util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
-//
-//  context.emplace_back();
-//
-//  context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
-//
-//  contextGRU->addToContext(context, dict, config, mustSplitUnknown());
-//
-//  if (!rawInputGRU.is_empty())
-//    rawInputGRU->addToContext(context, dict, config, mustSplitUnknown());
-//
-//  if (!treeEmbedding.is_empty())
-//    treeEmbedding->addToContext(context, dict, config, mustSplitUnknown());
-//
-//  splitTransGRU->addToContext(context, dict, config, mustSplitUnknown());
-//
-//  for (auto & gru : focusedLstms)
-//    gru->addToContext(context, dict, config, mustSplitUnknown());
-//
-//  if (!mustSplitUnknown() && context.size() > 1)
-//    util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size()));
-//
-//  return context;
-}
-
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
deleted file mode 100644
index e46d175..0000000
--- a/torch_modules/src/LSTMNetwork.cpp
+++ /dev/null
@@ -1,120 +0,0 @@
-#include "LSTMNetwork.hpp"
-
-LSTMNetworkImpl::LSTMNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d)
-{
-//  MyModule::ModuleOptions moduleOptions{true,bilstm,numLayers,lstmDropout,false};
-//  auto moduleOptionsAll = moduleOptions;
-//  std::get<4>(moduleOptionsAll) = true;
-//
-//  int currentOutputSize = embeddingsSize;
-//  int currentInputSize = 1;
-//
-//  contextLSTM = register_module("contextLSTM", ContextModule(columns, embeddingsSize, contextLSTMSize, bufferContext, stackContext, moduleOptions, unknownValueThreshold));
-//  contextLSTM->setFirstInputIndex(currentInputSize);
-//  currentOutputSize += contextLSTM->getOutputSize();
-//  currentInputSize += contextLSTM->getInputSize();
-//
-//  if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
-//  {
-//    rawInputLSTM = register_module("rawInputLSTM", RawInputModule(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, moduleOptionsAll));
-//    rawInputLSTM->setFirstInputIndex(currentInputSize);
-//    currentOutputSize += rawInputLSTM->getOutputSize();
-//    currentInputSize += rawInputLSTM->getInputSize();
-//  }
-//
-//  if (!treeEmbeddingColumns.empty())
-//  {
-//    treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbeddingModule(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,moduleOptions));
-//    treeEmbedding->setFirstInputIndex(currentInputSize);
-//    currentOutputSize += treeEmbedding->getOutputSize();
-//    currentInputSize += treeEmbedding->getInputSize();
-//  }
-//
-//  splitTransLSTM = register_module("splitTransLSTM", SplitTransModule(Config::maxNbAppliableSplitTransitions, embeddingsSize, splitTransLSTMSize, moduleOptionsAll));
-//  splitTransLSTM->setFirstInputIndex(currentInputSize);
-//  currentOutputSize += splitTransLSTM->getOutputSize();
-//  currentInputSize += splitTransLSTM->getInputSize();
-//
-//  for (unsigned int i = 0; i < focusedColumns.size(); i++)
-//  {
-//    focusedLstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), FocusedColumnModule(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedLSTMSize, moduleOptions)));
-//    focusedLstms.back()->setFirstInputIndex(currentInputSize);
-//    currentOutputSize += focusedLstms.back()->getOutputSize();
-//    currentInputSize += focusedLstms.back()->getInputSize();
-//  }
-//
-//  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
-//  if (drop2d)
-//    embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue));
-//  else
-//    embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue));
-//  inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout));
-//
-//  mlp = register_module("mlp", MLP(currentOutputSize, mlpParams));
-//
-//  for (auto & it : nbOutputsPerState)
-//    outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
-}
-
-torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
-{
-  return input;
-//  if (input.dim() == 1)
-//    input = input.unsqueeze(0);
-//
-//  auto embeddings = wordEmbeddings(input);
-//  if (embeddingsDropout2d.is_empty())
-//    embeddings = embeddingsDropout(embeddings);
-//  else
-//    embeddings = embeddingsDropout2d(embeddings);
-//
-//  std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)};
-//
-//  outputs.emplace_back(contextLSTM(embeddings));
-//
-//  if (!rawInputLSTM.is_empty())
-//    outputs.emplace_back(rawInputLSTM(embeddings));
-//
-//  if (!treeEmbedding.is_empty())
-//    outputs.emplace_back(treeEmbedding(embeddings));
-//
-//  outputs.emplace_back(splitTransLSTM(embeddings));
-//
-//  for (auto & lstm : focusedLstms)
-//    outputs.emplace_back(lstm(embeddings));
-//
-//  auto totalInput = inputDropout(torch::cat(outputs, 1));
-//
-//  return outputLayersPerState.at(getState())(mlp(totalInput));
-}
-
-std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
-{
-  std::vector<std::vector<long>> context;
-  return context;
-//  if (dict.size() >= maxNbEmbeddings)
-//    util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
-//
-//  context.emplace_back();
-//
-//  context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
-//
-//  contextLSTM->addToContext(context, dict, config, mustSplitUnknown());
-//
-//  if (!rawInputLSTM.is_empty())
-//    rawInputLSTM->addToContext(context, dict, config, mustSplitUnknown());
-//
-//  if (!treeEmbedding.is_empty())
-//    treeEmbedding->addToContext(context, dict, config, mustSplitUnknown());
-//
-//  splitTransLSTM->addToContext(context, dict, config, mustSplitUnknown());
-//
-//  for (auto & lstm : focusedLstms)
-//    lstm->addToContext(context, dict, config, mustSplitUnknown());
-//
-//  if (!mustSplitUnknown() && context.size() > 1)
-//    util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size()));
-//
-//  return context;
-}
-
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index 21edb84..47bf5b1 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -23,6 +23,14 @@ ModularNetworkImpl::ModularNetworkImpl(std::map<std::string,std::size_t> nbOutpu
     std::string name = fmt::format("{}_{}", modules.size(), splited.first);
     if (splited.first == "Context")
       modules.emplace_back(register_module(name, ContextModule(splited.second)));
+    else if (splited.first == "Focused")
+      modules.emplace_back(register_module(name, FocusedColumnModule(splited.second)));
+    else if (splited.first == "RawInput")
+      modules.emplace_back(register_module(name, RawInputModule(splited.second)));
+    else if (splited.first == "SplitTrans")
+      modules.emplace_back(register_module(name, SplitTransModule(Config::maxNbAppliableSplitTransitions, splited.second)));
+    else if (splited.first == "DepthLayerTree")
+      modules.emplace_back(register_module(name, DepthLayerTreeEmbeddingModule(splited.second)));
     else if (splited.first == "MLP")
     {
       mlpDef = splited.second;
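
Each module line of a Modular network definition is split into a keyword (splited.first) and an argument string (splited.second); the keyword only selects which constructor runs, and the argument string is forwarded verbatim to that constructor. A hypothetical definition fragment, assuming a "Name : arguments" line format (only the RawInput and SplitTrans argument grammars are confirmed by the regexes later in this patch):

    RawInput : Left{5} Right{5} GRU{1 1 0.0 1} In{32} Out{64}
    SplitTrans : LSTM{1 1 0.0 1} In{64} Out{128}
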
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index e14451d..9c5e541 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -1,26 +1,57 @@
 #include "RawInputModule.hpp"
 
-RawInputModule::RawInputModule(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options) : leftWindow(leftWindow), rightWindow(rightWindow)
+RawInputModuleImpl::RawInputModuleImpl(const std::string & definition)
 {
-  myModule = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
+  std::regex regex("(?:(?:\\s|\\t)*)Left\\{(.*)\\}(?:(?:\\s|\\t)*)Right\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+        {
+          try
+          {
+            leftWindow = std::stoi(sm.str(1));
+            rightWindow = std::stoi(sm.str(2));
+
+            auto subModuleType = sm.str(3);
+            auto subModuleArguments = util::split(sm.str(4), ' ');
+
+            auto options = MyModule::ModuleOptions(true)
+              .bidirectional(std::stoi(subModuleArguments[0]))
+              .num_layers(std::stoi(subModuleArguments[1]))
+              .dropout(std::stof(subModuleArguments[2]))
+              .complete(std::stoi(subModuleArguments[3]));
+
+            int inSize = std::stoi(sm.str(5));
+            int outSize = std::stoi(sm.str(6));
+
+            wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
+
+            if (subModuleType == "LSTM")
+              myModule = register_module("myModule", LSTM(inSize, outSize, options));
+            else if (subModuleType == "GRU")
+              myModule = register_module("myModule", GRU(inSize, outSize, options));
+            else
+              util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
+
+          } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+        }))
+    util::myThrow(fmt::format("invalid definition '{}'", definition));
 }
 
-torch::Tensor RawInputModule::forward(torch::Tensor input)
+torch::Tensor RawInputModuleImpl::forward(torch::Tensor input)
 {
-  return myModule->forward(input.narrow(1, firstInputIndex, getInputSize()));
+  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize())));
 }
 
-std::size_t RawInputModule::getOutputSize()
+std::size_t RawInputModuleImpl::getOutputSize()
 {
   return myModule->getOutputSize(leftWindow + rightWindow + 1);
 }
 
-std::size_t RawInputModule::getInputSize()
+std::size_t RawInputModuleImpl::getInputSize()
 {
   return leftWindow + rightWindow + 1;
 }
 
-void RawInputModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   if (leftWindow < 0 or rightWindow < 0)
     return;
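
Per the regex above, a RawInput definition has the shape Left{l} Right{r} Type{bidirectional numLayers dropout complete} In{inSize} Out{outSize}, where Type is LSTM or GRU and the four Type arguments are consumed in that order. A minimal sketch, assuming the TORCH_MODULE wrapper name RawInputModule used in ModularNetwork.cpp:

    // Hypothetical definition string: 5 characters of left context and 5 of
    // right context, a unidirectional 1-layer GRU with no dropout, and the
    // 'complete' output mode enabled.
    RawInputModule rawInput("Left{5} Right{5} GRU{0 1 0.0 1} In{32} Out{64}");
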
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index 45a48df..ab1276c 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -1,27 +1,55 @@
 #include "SplitTransModule.hpp"
 #include "Transition.hpp"
 
-SplitTransModule::SplitTransModule(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, MyModule::ModuleOptions options) : maxNbTrans(maxNbTrans)
+SplitTransModuleImpl::SplitTransModuleImpl(int maxNbTrans, const std::string & definition)
 {
-  myModule = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options));
+  std::regex regex("(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+        {
+          try
+          {
+            auto subModuleType = sm.str(1);
+            auto subModuleArguments = util::split(sm.str(2), ' ');
+
+            auto options = MyModule::ModuleOptions(true)
+              .bidirectional(std::stoi(subModuleArguments[0]))
+              .num_layers(std::stoi(subModuleArguments[1]))
+              .dropout(std::stof(subModuleArguments[2]))
+              .complete(std::stoi(subModuleArguments[3]));
+
+            int inSize = std::stoi(sm.str(3));
+            int outSize = std::stoi(sm.str(4));
+
+            wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
+
+            if (subModuleType == "LSTM")
+              myModule = register_module("myModule", LSTM(inSize, outSize, options));
+            else if (subModuleType == "GRU")
+              myModule = register_module("myModule", GRU(inSize, outSize, options));
+            else
+              util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
+
+          } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+        }))
+    util::myThrow(fmt::format("invalid definition '{}'", definition));
 }
 
-torch::Tensor SplitTransModule::forward(torch::Tensor input)
+torch::Tensor SplitTransModuleImpl::forward(torch::Tensor input)
 {
-  return myModule->forward(input.narrow(1, firstInputIndex, getInputSize()));
+  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize())));
 }
 
-std::size_t SplitTransModule::getOutputSize()
+std::size_t SplitTransModuleImpl::getOutputSize()
 {
   return myModule->getOutputSize(maxNbTrans);
 }
 
-std::size_t SplitTransModule::getInputSize()
+std::size_t SplitTransModuleImpl::getInputSize()
 {
   return maxNbTrans;
 }
 
-void SplitTransModule::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
+void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   auto & splitTransitions = config.getAppliableSplitTransitions();
   for (auto & contextElement : context)
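
The SplitTrans grammar is the RawInput grammar without the window arguments, since its input length is fixed by Config::maxNbAppliableSplitTransitions (passed in from ModularNetwork.cpp). A minimal sketch under the same assumptions:

    // Hypothetical: a bidirectional 1-layer LSTM over the appliable split
    // transitions, embedding them in 64 dimensions with a recurrent hidden
    // size of 128.
    SplitTransModule splitTrans(Config::maxNbAppliableSplitTransitions,
                                "LSTM{1 1 0.0 1} In{64} Out{128}");
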
-- 
GitLab