Skip to content
Snippets Groups Projects
Commit 983dc489 authored by Franck Dary
Browse files

Added Dropout to CNNNetwork

parent 1efb791a
No related branches found
No related tags found
No related merge requests found
......@@ -8,6 +8,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement)
{
torch::AutoGradMode useGrad(false);
machine.getClassifier()->getNN()->train(false);
config.addPredicted(machine.getPredicted());
constexpr int printInterval = 50;
......
......@@ -17,6 +17,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl
int rawInputSize;
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Dropout embeddingsDropout{nullptr};
torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr};
CNN contextCNN{nullptr};
......
......@@ -20,6 +20,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize();
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize));
int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize;
for (auto & col : focusedColumns)
......@@ -37,7 +38,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
if (input.dim() == 1)
input = input.unsqueeze(0);
auto embeddings = wordEmbeddings(input);
auto embeddings = embeddingsDropout(wordEmbeddings(input));
auto context = embeddings.narrow(1, rawInputSize, columns.size()*(1+leftBorder+rightBorder));
context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});
......
......@@ -91,6 +91,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
int currentBatchNumber = 0;
torch::AutoGradMode useGrad(train);
machine.getClassifier()->getNN()->train(train);
auto lossFct = torch::nn::CrossEntropyLoss();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment