From 983dc48944f2517fe1a4f33bb694f8c5ddb571d1 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 9 Mar 2020 22:27:45 +0100
Subject: [PATCH] Added Dropout to CNNNetwork

---
 decoder/src/Decoder.cpp              | 1 +
 torch_modules/include/CNNNetwork.hpp | 1 +
 torch_modules/src/CNNNetwork.cpp     | 3 ++-
 trainer/src/Trainer.cpp              | 1 +
 4 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 48eb5b1..bc0da8e 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -8,6 +8,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
 void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement)
 {
   torch::AutoGradMode useGrad(false);
+  machine.getClassifier()->getNN()->train(false);
   config.addPredicted(machine.getPredicted());
 
   constexpr int printInterval = 50;
diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp
index f193ebc..2c4b507 100644
--- a/torch_modules/include/CNNNetwork.hpp
+++ b/torch_modules/include/CNNNetwork.hpp
@@ -17,6 +17,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl
   int rawInputSize;
 
   torch::nn::Embedding wordEmbeddings{nullptr};
+  torch::nn::Dropout embeddingsDropout{nullptr};
   torch::nn::Linear linear1{nullptr};
   torch::nn::Linear linear2{nullptr};
   CNN contextCNN{nullptr};
diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp
index 285ad8c..6acdc8b 100644
--- a/torch_modules/src/CNNNetwork.cpp
+++ b/torch_modules/src/CNNNetwork.cpp
@@ -20,6 +20,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
   int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize();
 
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
+  embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
   contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize));
   int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize;
   for (auto & col : focusedColumns)
@@ -37,7 +38,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
   if (input.dim() == 1)
     input = input.unsqueeze(0);
 
-  auto embeddings = wordEmbeddings(input);
+  auto embeddings = embeddingsDropout(wordEmbeddings(input));
 
   auto context = embeddings.narrow(1, rawInputSize, columns.size()*(1+leftBorder+rightBorder));
   context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 2b68072..0b1b340 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -91,6 +91,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
   int currentBatchNumber = 0;
 
   torch::AutoGradMode useGrad(train);
+  machine.getClassifier()->getNN()->train(train);
 
   auto lossFct = torch::nn::CrossEntropyLoss();
 
-- 
GitLab