From 4bebd4a82cf13e8c7cac23b8c062695d401549ee Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 24 Feb 2020 18:07:10 +0100 Subject: [PATCH] Added dropout layer to ConcatWordsNetwork --- torch_modules/include/ConcatWordsNetwork.hpp | 1 + torch_modules/src/ConcatWordsNetwork.cpp | 3 ++- trainer/include/Trainer.hpp | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/torch_modules/include/ConcatWordsNetwork.hpp b/torch_modules/include/ConcatWordsNetwork.hpp index 064a00e..7152eba 100644 --- a/torch_modules/include/ConcatWordsNetwork.hpp +++ b/torch_modules/include/ConcatWordsNetwork.hpp @@ -10,6 +10,7 @@ class ConcatWordsNetworkImpl : public NeuralNetworkImpl torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; + torch::nn::Dropout dropout{nullptr}; public : diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp index b72694c..fd9f2b8 100644 --- a/torch_modules/src/ConcatWordsNetwork.cpp +++ b/torch_modules/src/ConcatWordsNetwork.cpp @@ -10,12 +10,13 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, in wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500)); linear2 = register_module("linear2", torch::nn::Linear(500, nbOutputs)); + dropout = register_module("dropout", torch::nn::Dropout(0.3)); } torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input) { // input dim = {batch, sequence, embeddings} - auto wordsAsEmb = wordEmbeddings(input); + auto wordsAsEmb = dropout(wordEmbeddings(input)); // reshaped dim = {batch, sequence of embeddings} auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)}); diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index a63c977..177f5f2 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -18,7 +18,7 @@ class Trainer DataLoader dataLoader{nullptr}; std::unique_ptr<torch::optim::Adam> optimizer; std::size_t epochNumber{0}; - int batchSize{100}; + int batchSize{50}; int nbExamples{0}; public : -- GitLab