From 4bebd4a82cf13e8c7cac23b8c062695d401549ee Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 24 Feb 2020 18:07:10 +0100
Subject: [PATCH] Added dropout layer to ConcatWordsNetwork

---
 torch_modules/include/ConcatWordsNetwork.hpp | 1 +
 torch_modules/src/ConcatWordsNetwork.cpp     | 3 ++-
 trainer/include/Trainer.hpp                  | 2 +-
 3 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/torch_modules/include/ConcatWordsNetwork.hpp b/torch_modules/include/ConcatWordsNetwork.hpp
index 064a00e..7152eba 100644
--- a/torch_modules/include/ConcatWordsNetwork.hpp
+++ b/torch_modules/include/ConcatWordsNetwork.hpp
@@ -10,6 +10,7 @@ class ConcatWordsNetworkImpl : public NeuralNetworkImpl
   torch::nn::Embedding wordEmbeddings{nullptr};
   torch::nn::Linear linear1{nullptr};
   torch::nn::Linear linear2{nullptr};
+  torch::nn::Dropout dropout{nullptr};
 
   public :
 
diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp
index b72694c..fd9f2b8 100644
--- a/torch_modules/src/ConcatWordsNetwork.cpp
+++ b/torch_modules/src/ConcatWordsNetwork.cpp
@@ -10,12 +10,13 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, in
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
   linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500));
   linear2 = register_module("linear2", torch::nn::Linear(500, nbOutputs));
+  dropout = register_module("dropout", torch::nn::Dropout(0.3));
 }
 
 torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
 {
   // input dim = {batch, sequence, embeddings}
-  auto wordsAsEmb = wordEmbeddings(input);
+  auto wordsAsEmb = dropout(wordEmbeddings(input));
   // reshaped dim = {batch, sequence of embeddings}
   auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)});
 
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index a63c977..177f5f2 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -18,7 +18,7 @@ class Trainer
   DataLoader dataLoader{nullptr};
   std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
-  int batchSize{100};
+  int batchSize{50};
   int nbExamples{0};
 
   public :
-- 
GitLab