diff --git a/torch_modules/include/ConcatWordsNetwork.hpp b/torch_modules/include/ConcatWordsNetwork.hpp
index 064a00eda70334f9e104db793b6476e94f43e36c..7152eba2e06aa74b909b2136514d64120b6c0b8d 100644
--- a/torch_modules/include/ConcatWordsNetwork.hpp
+++ b/torch_modules/include/ConcatWordsNetwork.hpp
@@ -10,6 +10,7 @@ class ConcatWordsNetworkImpl : public NeuralNetworkImpl
   torch::nn::Embedding wordEmbeddings{nullptr};
   torch::nn::Linear linear1{nullptr};
   torch::nn::Linear linear2{nullptr};
+  torch::nn::Dropout dropout{nullptr};
 
   public :
 
diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp
index b72694c4fa296a2215f8c3d65991977a2e69e6dc..fd9f2b8537a6685bbca24e2d664f570160c4dfe0 100644
--- a/torch_modules/src/ConcatWordsNetwork.cpp
+++ b/torch_modules/src/ConcatWordsNetwork.cpp
@@ -10,12 +10,13 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, in
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
   linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500));
   linear2 = register_module("linear2", torch::nn::Linear(500, nbOutputs));
+  dropout = register_module("dropout", torch::nn::Dropout(0.3));
 }
 
 torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
 {
   // input dim = {batch, sequence, embeddings}
-  auto wordsAsEmb = wordEmbeddings(input);
+  auto wordsAsEmb = dropout(wordEmbeddings(input));
   // reshaped dim = {batch, sequence of embeddings}
   auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)});
 
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index a63c977799d1b2b334baca1153e1f4dd2bc9bf40..177f5f28e4feb0ecdeafe3a64bfffa0b11a63c5b 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -18,7 +18,7 @@ class Trainer
   DataLoader dataLoader{nullptr};
   std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
-  int batchSize{100};
+  int batchSize{50};
   int nbExamples{0};
 
   public :