diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 48eb5b11a2d8671c0c87301285e0e95c668ecbe8..bc0da8ef3dec552c7459d14d03ca6e196fe2c84d 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -8,6 +8,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement) { torch::AutoGradMode useGrad(false); + machine.getClassifier()->getNN()->train(false); config.addPredicted(machine.getPredicted()); constexpr int printInterval = 50; diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp index f193ebc9233f7606ce825f332b7cb9dfd25683b5..2c4b507e6dc22ce1db32d5a7614d08519db24a94 100644 --- a/torch_modules/include/CNNNetwork.hpp +++ b/torch_modules/include/CNNNetwork.hpp @@ -17,6 +17,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl int rawInputSize; torch::nn::Embedding wordEmbeddings{nullptr}; + torch::nn::Dropout embeddingsDropout{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; CNN contextCNN{nullptr}; diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 285ad8c40fcf7ba164a7a32c2dc5501f02ab0af7..6acdc8bbfe573977781d917f002ea2f30261075f 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -20,6 +20,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize(); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); + embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3)); contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize)); int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize; for (auto & col : focusedColumns) @@ -37,7 +38,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) if (input.dim() == 1) input = input.unsqueeze(0); - auto embeddings = wordEmbeddings(input); + auto embeddings = embeddingsDropout(wordEmbeddings(input)); auto context = embeddings.narrow(1, rawInputSize, columns.size()*(1+leftBorder+rightBorder)); context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}); diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 2b68072465a599038ba0420003514bf9581f2db3..0b1b3406af3d4483480f31dcc7ba625cb24f8a69 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -91,6 +91,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance int currentBatchNumber = 0; torch::AutoGradMode useGrad(train); + machine.getClassifier()->getNN()->train(train); auto lossFct = torch::nn::CrossEntropyLoss();