From f00fdc1f947d6b214d854c3786b72c44124ec45b Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 10 Mar 2020 16:14:07 +0100
Subject: [PATCH] Added dropouts to CNNNetwork and used isCommentPredicted in context extraction

---
 torch_modules/include/CNNNetwork.hpp | 2 ++
 torch_modules/src/CNNNetwork.cpp     | 6 ++++--
 torch_modules/src/NeuralNetwork.cpp  | 4 ++--
 3 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp
index 2c4b507..cab39f0 100644
--- a/torch_modules/include/CNNNetwork.hpp
+++ b/torch_modules/include/CNNNetwork.hpp
@@ -18,6 +18,8 @@ class CNNNetworkImpl : public NeuralNetworkImpl
 
   torch::nn::Embedding wordEmbeddings{nullptr};
   torch::nn::Dropout embeddingsDropout{nullptr};
+  torch::nn::Dropout cnnDropout{nullptr};
+  torch::nn::Dropout hiddenDropout{nullptr};
   torch::nn::Linear linear1{nullptr};
   torch::nn::Linear linear2{nullptr};
   CNN contextCNN{nullptr};
diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp
index 6acdc8b..8889fe0 100644
--- a/torch_modules/src/CNNNetwork.cpp
+++ b/torch_modules/src/CNNNetwork.cpp
@@ -21,6 +21,8 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
 
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
   embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
+  cnnDropout = register_module("cnn_dropout", torch::nn::Dropout(0.3));
+  hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
   contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize));
   int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize;
   for (auto & col : focusedColumns)
@@ -67,9 +69,9 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
 
   cnnOutputs.emplace_back(contextCNN(context.unsqueeze(1)));
 
-  auto totalInput = torch::cat(cnnOutputs, 1);
+  auto totalInput = cnnDropout(torch::cat(cnnOutputs, 1));
 
-  return linear2(torch::relu(linear1(totalInput)));
+  return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
 }
 
 std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index 0ef9f8d..37c206c 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -6,7 +6,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config
 {
   std::stack<long> leftContext;
   for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < leftBorder; --index)
-    if (!config.isComment(index))
+    if (!config.isCommentPredicted(index))
       leftContext.push(index);
 
   std::vector<long> context;
@@ -20,7 +20,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config
   }
 
   for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < leftBorder+rightBorder+1; ++index)
-    if (!config.isComment(index))
+    if (!config.isCommentPredicted(index))
       context.emplace_back(index);
 
   while (context.size() < leftBorder+rightBorder+1)
-- 
GitLab