From f00fdc1f947d6b214d854c3786b72c44124ec45b Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 10 Mar 2020 16:14:07 +0100 Subject: [PATCH] Added dropouts to CNNNetwork --- torch_modules/include/CNNNetwork.hpp | 2 ++ torch_modules/src/CNNNetwork.cpp | 6 ++++-- torch_modules/src/NeuralNetwork.cpp | 4 ++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp index 2c4b507..cab39f0 100644 --- a/torch_modules/include/CNNNetwork.hpp +++ b/torch_modules/include/CNNNetwork.hpp @@ -18,6 +18,8 @@ class CNNNetworkImpl : public NeuralNetworkImpl torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Dropout embeddingsDropout{nullptr}; + torch::nn::Dropout cnnDropout{nullptr}; + torch::nn::Dropout hiddenDropout{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; CNN contextCNN{nullptr}; diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 6acdc8b..8889fe0 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -21,6 +21,8 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3)); + cnnDropout = register_module("cnn_dropout", torch::nn::Dropout(0.3)); + hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3)); contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize)); int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize; for (auto & col : focusedColumns) @@ -67,9 +69,9 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) cnnOutputs.emplace_back(contextCNN(context.unsqueeze(1))); - auto totalInput = torch::cat(cnnOutputs, 1); + auto totalInput = cnnDropout(torch::cat(cnnOutputs, 1)); - return linear2(torch::relu(linear1(totalInput))); + return linear2(hiddenDropout(torch::relu(linear1(totalInput)))); } std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index 0ef9f8d..37c206c 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -6,7 +6,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config { std::stack<long> leftContext; for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < leftBorder; --index) - if (!config.isComment(index)) + if (!config.isCommentPredicted(index)) leftContext.push(index); std::vector<long> context; @@ -20,7 +20,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config } for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < leftBorder+rightBorder+1; ++index) - if (!config.isComment(index)) + if (!config.isCommentPredicted(index)) context.emplace_back(index); while (context.size() < leftBorder+rightBorder+1) -- GitLab