From 983dc48944f2517fe1a4f33bb694f8c5ddb571d1 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 9 Mar 2020 22:27:45 +0100 Subject: [PATCH] Added Dropout to CNNNetwork --- decoder/src/Decoder.cpp | 1 + torch_modules/include/CNNNetwork.hpp | 1 + torch_modules/src/CNNNetwork.cpp | 3 ++- trainer/src/Trainer.cpp | 1 + 4 files changed, 5 insertions(+), 1 deletion(-) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 48eb5b1..bc0da8e 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -8,6 +8,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement) { torch::AutoGradMode useGrad(false); + machine.getClassifier()->getNN()->train(false); config.addPredicted(machine.getPredicted()); constexpr int printInterval = 50; diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp index f193ebc..2c4b507 100644 --- a/torch_modules/include/CNNNetwork.hpp +++ b/torch_modules/include/CNNNetwork.hpp @@ -17,6 +17,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl int rawInputSize; torch::nn::Embedding wordEmbeddings{nullptr}; + torch::nn::Dropout embeddingsDropout{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; CNN contextCNN{nullptr}; diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 285ad8c..6acdc8b 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -20,6 +20,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize(); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); + embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3)); contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize)); int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize; for (auto & col : focusedColumns) @@ -37,7 +38,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) if (input.dim() == 1) input = input.unsqueeze(0); - auto embeddings = wordEmbeddings(input); + auto embeddings = embeddingsDropout(wordEmbeddings(input)); auto context = embeddings.narrow(1, rawInputSize, columns.size()*(1+leftBorder+rightBorder)); context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}); diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 2b68072..0b1b340 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -91,6 +91,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance int currentBatchNumber = 0; torch::AutoGradMode useGrad(train); + machine.getClassifier()->getNN()->train(train); auto lossFct = torch::nn::CrossEntropyLoss(); -- GitLab