diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index ce9c91e585216b09a5a2624fd8631434188cb8a5..089ee2bba666d42a759d8ba21c92a38b40727446 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -75,8 +75,9 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
   std::vector<int> focusedBuffer, focusedStack;
   std::vector<std::string> focusedColumns;
   std::vector<int> maxNbElements;
+  std::vector<std::pair<int, float>> mlp;
   int rawInputLeftWindow, rawInputRightWindow;
-  int embeddingsSize, hiddenSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers;
+  int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers;
   bool bilstm;
   float lstmDropout;
 
@@ -162,12 +163,16 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings size :) value"));
 
-  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Hidden size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&hiddenSize](auto sm)
+  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:MLP :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&mlp](auto sm)
        {
-          hiddenSize = std::stoi(sm.str(1));
+          auto params = util::split(sm.str(1), ' ');
+          if (params.size() % 2)
+            util::myThrow("MLP must have even number of parameters");
+          for (unsigned int i = 0; i < params.size()/2; i++)
+            mlp.emplace_back(std::stoi(params[2*i]), std::stof(params[2*i+1]));
           curIndex++;
         }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Hidden size :) value"));
+    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(MLP :) {hidden1 dropout1 hidden2 dropout2...}"));
 
   if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Context LSTM size :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&contextLSTMSize](auto sm)
         {
@@ -218,6 +223,6 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
         }))
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(LSTM dropout :) value"));
 
-  this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, hiddenSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout));
+  this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout));
 }
 
diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp
index 6f2b46fc2cb73bf44ff73a7c8413175d4e772451..07d689effc2a6672ca78ae4ba91ad281b39204af 100644
--- a/torch_modules/include/LSTMNetwork.hpp
+++ b/torch_modules/include/LSTMNetwork.hpp
@@ -6,6 +6,7 @@
 #include "RawInputLSTM.hpp"
 #include "SplitTransLSTM.hpp"
 #include "FocusedColumnLSTM.hpp"
+#include "MLP.hpp"
 
 class LSTMNetworkImpl : public NeuralNetworkImpl
 {
@@ -14,10 +15,8 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
   torch::nn::Embedding wordEmbeddings{nullptr};
   torch::nn::Dropout embeddingsDropout{nullptr};
   torch::nn::Dropout lstmDropout{nullptr};
-  torch::nn::Dropout hiddenDropout{nullptr};
-  torch::nn::Linear linear1{nullptr};
-  torch::nn::Linear linear2{nullptr};
 
+  MLP mlp{nullptr};
   ContextLSTM contextLSTM{nullptr};
   RawInputLSTM rawInputLSTM{nullptr};
   SplitTransLSTM splitTransLSTM{nullptr};
@@ -27,7 +26,7 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
 
   public :
 
-  LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, int hiddenSize, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout);
+  LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout);
   torch::Tensor forward(torch::Tensor input) override;
   std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
 };
diff --git a/torch_modules/include/MLP.hpp b/torch_modules/include/MLP.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..71520f2be3712e78cd9fe4987d50ade32d289825
--- /dev/null
+++ b/torch_modules/include/MLP.hpp
@@ -0,0 +1,28 @@
+#ifndef MLP__H
+#define MLP__H
+
+#include <torch/torch.h>
+
+// Multi-layer perceptron: a configurable stack of hidden Linear layers, each
+// followed by its own Dropout, ending in a plain linear output layer
+// (implementation in MLP.cpp). NOTE(review): the guard "MLP__H" contains a
+// double underscore, reserved for the implementation — confirm project style.
+class MLPImpl : public torch::nn::Module
+{
+  private :
+
+  // Hidden layers plus the final output layer, in forward order.
+  std::vector<torch::nn::Linear> layers;
+  // dropouts[i] is applied to the output of layers[i]; hidden layers only.
+  std::vector<torch::nn::Dropout> dropouts;
+
+  public :
+
+  // params: one (hiddenLayerSize, dropoutRate) pair per hidden layer.
+  MLPImpl(int inputSize, int outputSize, std::vector<std::pair<int, float>> params);
+  torch::Tensor forward(torch::Tensor input);
+};
+TORCH_MODULE(MLP);
+
+#endif
+
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
index cfa8c458679bfefea76ea62b239c08fa63428b5e..ab06806692b716da93056ed7091f0a6daf7f2ba5 100644
--- a/torch_modules/src/LSTMNetwork.cpp
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -1,6 +1,6 @@
 #include "LSTMNetwork.hpp"
 
-LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, int hiddenSize, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout)
+LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout)
 {
   LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false};
   auto lstmOptionsAll = lstmOptions;
@@ -38,10 +38,8 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
 
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
   embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
-  hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
 
-  linear1 = register_module("linear1", torch::nn::Linear(currentOutputSize, hiddenSize));
-  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
+  mlp = register_module("mlp", MLP(currentOutputSize, nbOutputs, mlpParams));
 }
 
 torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
@@ -65,7 +63,7 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
 
   auto totalInput = torch::cat(outputs, 1);
 
-  return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
+  return mlp(totalInput);
 }
 
 std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
diff --git a/torch_modules/src/MLP.cpp b/torch_modules/src/MLP.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..52a4c235d5489165741ab421995ce9055401e67b
--- /dev/null
+++ b/torch_modules/src/MLP.cpp
@@ -0,0 +1,31 @@
+#include "MLP.hpp"
+#include "fmt/core.h"
+
+// Builds one (Linear -> Dropout) stage per (size, dropout) pair in params,
+// then a final Linear projecting the last hidden size onto outputSize.
+// With empty params the module degenerates to a single Linear layer.
+MLPImpl::MLPImpl(int inputSize, int outputSize, std::vector<std::pair<int, float>> params)
+{
+  int inSize = inputSize;
+
+  for (auto & param : params)
+  {
+    layers.emplace_back(register_module(fmt::format("layer_{}", layers.size()), torch::nn::Linear(inSize, param.first)));
+    dropouts.emplace_back(register_module(fmt::format("dropout_{}", dropouts.size()), torch::nn::Dropout(param.second)));
+    inSize = param.first;
+  }
+
+  layers.emplace_back(register_module(fmt::format("layer_{}", layers.size()), torch::nn::Linear(inSize, outputSize)));
+}
+
+// Applies relu(dropout(linear(x))) for every hidden layer, then the final
+// Linear without activation or dropout (raw scores for the caller).
+torch::Tensor MLPImpl::forward(torch::Tensor input)
+{
+  torch::Tensor output = input;
+  for (unsigned int i = 0; i < layers.size()-1; i++)
+    output = torch::relu(dropouts[i](layers[i](output)));
+
+  return layers.back()(output);
+}
+