diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 89c7bb3dd8e037c9f1674f2a684729e75761373c..b3e0710eb4e23dc155c629618de8a63f95c522fb 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -95,7 +95,7 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
   std::vector<std::pair<int, float>> mlp;
   int rawInputLeftWindow, rawInputRightWindow;
   int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, treeEmbeddingSize;
-  bool bilstm;
+  bool bilstm, drop2d;
   float lstmDropout, embeddingsDropout, totalInputDropout;
 
   if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Unknown value threshold :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&unknownValueThreshold](auto sm)
@@ -254,6 +254,13 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
         }))
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings dropout :) value"));
 
+  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Dropout 2d :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&drop2d](auto sm)
+        {
+          drop2d = sm.str(1) == "true";
+          curIndex++;
+        }))
+    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Dropout 2d :) true|false"));
+
   if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingColumns](auto sm)
         {
           treeEmbeddingColumns = util::split(sm.str(1), ' ');
@@ -292,7 +299,7 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
         }))
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding size :) value"));
 
-  this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout));
+  this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d));
 }
 
 void Classifier::loadOptimizer(std::filesystem::path path)
diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp
index 550fae1480e7be362aa7c04f502b591753a6279b..e742b9723a3ba5c9e14485dac5f5f4a9dc46c5ad 100644
--- a/torch_modules/include/LSTMNetwork.hpp
+++ b/torch_modules/include/LSTMNetwork.hpp
@@ -14,6 +14,7 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
   private :
 
   torch::nn::Embedding wordEmbeddings{nullptr};
+  torch::nn::Dropout2d embeddingsDropout2d{nullptr};
   torch::nn::Dropout embeddingsDropout{nullptr};
   torch::nn::Dropout inputDropout{nullptr};
 
@@ -24,12 +25,9 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
   DepthLayerTreeEmbedding treeEmbedding{nullptr};
   std::vector<FocusedColumnLSTM> focusedLstms;
 
-  bool hasRawInputLSTM{false};
-  bool hasTreeEmbedding{false};
-
   public :
 
-  LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout);
+  LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d);
   torch::Tensor forward(torch::Tensor input) override;
   std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
 };
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
index 98cbb08ffc18089af29216bb219974ee672517cc..cfa004e58e66b3162231fd4046b499ac62dc5636 100644
--- a/torch_modules/src/LSTMNetwork.cpp
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -1,6 +1,6 @@
 #include "LSTMNetwork.hpp"
 
-LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout)
+LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d)
 {
   LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false};
   auto lstmOptionsAll = lstmOptions;
@@ -16,7 +16,6 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
 
   if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0)
   {
-    hasRawInputLSTM = true;
     rawInputLSTM = register_module("rawInputLSTM", RawInputLSTM(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, lstmOptionsAll));
     rawInputLSTM->setFirstInputIndex(currentInputSize);
     currentOutputSize += rawInputLSTM->getOutputSize();
@@ -25,7 +24,6 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
 
   if (!treeEmbeddingColumns.empty())
   {
-    hasTreeEmbedding = true;
     treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,lstmOptions));
     treeEmbedding->setFirstInputIndex(currentInputSize);
     currentOutputSize += treeEmbedding->getOutputSize();
@@ -46,7 +44,10 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
   }
 
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
-  embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue));
+  if (drop2d)
+    embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue));
+  else
+    embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue));
   inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout));
 
   mlp = register_module("mlp", MLP(currentOutputSize, nbOutputs, mlpParams));
@@ -57,16 +58,20 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
   if (input.dim() == 1)
     input = input.unsqueeze(0);
 
-  auto embeddings = embeddingsDropout(wordEmbeddings(input));
+  auto embeddings = wordEmbeddings(input);
+  if (embeddingsDropout2d.is_empty())
+    embeddings = embeddingsDropout(embeddings);
+  else
+    embeddings = embeddingsDropout2d(embeddings);
 
   std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)};
 
   outputs.emplace_back(contextLSTM(embeddings));
 
-  if (hasRawInputLSTM)
+  if (!rawInputLSTM.is_empty())
     outputs.emplace_back(rawInputLSTM(embeddings));
 
-  if (hasTreeEmbedding)
+  if (!treeEmbedding.is_empty())
     outputs.emplace_back(treeEmbedding(embeddings));
 
   outputs.emplace_back(splitTransLSTM(embeddings));
@@ -91,10 +96,10 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
 
   contextLSTM->addToContext(context, dict, config);
 
-  if (hasRawInputLSTM)
+  if (!rawInputLSTM.is_empty())
     rawInputLSTM->addToContext(context, dict, config);
 
-  if (hasTreeEmbedding)
+  if (!treeEmbedding.is_empty())
     treeEmbedding->addToContext(context, dict, config);
 
   splitTransLSTM->addToContext(context, dict, config);