diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 48eb5b11a2d8671c0c87301285e0e95c668ecbe8..bc0da8ef3dec552c7459d14d03ca6e196fe2c84d 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -8,6 +8,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
 void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement)
 {
   torch::AutoGradMode useGrad(false);
+  machine.getClassifier()->getNN()->train(false);
   config.addPredicted(machine.getPredicted());
 
   constexpr int printInterval = 50;
diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp
index f193ebc9233f7606ce825f332b7cb9dfd25683b5..2c4b507e6dc22ce1db32d5a7614d08519db24a94 100644
--- a/torch_modules/include/CNNNetwork.hpp
+++ b/torch_modules/include/CNNNetwork.hpp
@@ -17,6 +17,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl
   int rawInputSize;
 
   torch::nn::Embedding wordEmbeddings{nullptr};
+  torch::nn::Dropout embeddingsDropout{nullptr};
   torch::nn::Linear linear1{nullptr};
   torch::nn::Linear linear2{nullptr};
   CNN contextCNN{nullptr};
diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp
index 285ad8c40fcf7ba164a7a32c2dc5501f02ab0af7..6acdc8bbfe573977781d917f002ea2f30261075f 100644
--- a/torch_modules/src/CNNNetwork.cpp
+++ b/torch_modules/src/CNNNetwork.cpp
@@ -20,6 +20,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
   int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize();
 
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
+  embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
   contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize));
   int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize;
   for (auto & col : focusedColumns)
@@ -37,7 +38,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
   if (input.dim() == 1)
     input = input.unsqueeze(0);
 
-  auto embeddings = wordEmbeddings(input);
+  auto embeddings = embeddingsDropout(wordEmbeddings(input));
 
   auto context = embeddings.narrow(1, rawInputSize, columns.size()*(1+leftBorder+rightBorder));
   context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 2b68072465a599038ba0420003514bf9581f2db3..0b1b3406af3d4483480f31dcc7ba625cb24f8a69 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -91,6 +91,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
   int currentBatchNumber = 0;
 
   torch::AutoGradMode useGrad(train);
+  machine.getClassifier()->getNN()->train(train);
 
   auto lossFct = torch::nn::CrossEntropyLoss();