Skip to content
Snippets Groups Projects
Commit 983dc489 authored by Franck Dary's avatar Franck Dary
Browse files

Added Dropout to CNNNetwork

parent 1efb791a
Branches
No related tags found
No related merge requests found
...@@ -8,6 +8,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) ...@@ -8,6 +8,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement) void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement)
{ {
torch::AutoGradMode useGrad(false); torch::AutoGradMode useGrad(false);
machine.getClassifier()->getNN()->train(false);
config.addPredicted(machine.getPredicted()); config.addPredicted(machine.getPredicted());
constexpr int printInterval = 50; constexpr int printInterval = 50;
......
...@@ -17,6 +17,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl ...@@ -17,6 +17,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl
int rawInputSize; int rawInputSize;
torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Dropout embeddingsDropout{nullptr};
torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr}; torch::nn::Linear linear2{nullptr};
CNN contextCNN{nullptr}; CNN contextCNN{nullptr};
......
...@@ -20,6 +20,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i ...@@ -20,6 +20,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize(); int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize();
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize)); contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize));
int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize; int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize;
for (auto & col : focusedColumns) for (auto & col : focusedColumns)
...@@ -37,7 +38,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) ...@@ -37,7 +38,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
if (input.dim() == 1) if (input.dim() == 1)
input = input.unsqueeze(0); input = input.unsqueeze(0);
auto embeddings = wordEmbeddings(input); auto embeddings = embeddingsDropout(wordEmbeddings(input));
auto context = embeddings.narrow(1, rawInputSize, columns.size()*(1+leftBorder+rightBorder)); auto context = embeddings.narrow(1, rawInputSize, columns.size()*(1+leftBorder+rightBorder));
context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}); context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});
......
...@@ -91,6 +91,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance ...@@ -91,6 +91,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
int currentBatchNumber = 0; int currentBatchNumber = 0;
torch::AutoGradMode useGrad(train); torch::AutoGradMode useGrad(train);
machine.getClassifier()->getNN()->train(train);
auto lossFct = torch::nn::CrossEntropyLoss(); auto lossFct = torch::nn::CrossEntropyLoss();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment