diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp index 2c4b507e6dc22ce1db32d5a7614d08519db24a94..cab39f07267ed4422e441ceafd8bd0c2d6adf85b 100644 --- a/torch_modules/include/CNNNetwork.hpp +++ b/torch_modules/include/CNNNetwork.hpp @@ -18,6 +18,8 @@ class CNNNetworkImpl : public NeuralNetworkImpl torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Dropout embeddingsDropout{nullptr}; + torch::nn::Dropout cnnDropout{nullptr}; + torch::nn::Dropout hiddenDropout{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; CNN contextCNN{nullptr}; diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 6acdc8bbfe573977781d917f002ea2f30261075f..8889fe09792f593d5f5985297b563746668126a4 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -21,6 +21,8 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3)); + cnnDropout = register_module("cnn_dropout", torch::nn::Dropout(0.3)); + hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3)); contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize)); int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize; for (auto & col : focusedColumns) @@ -67,9 +69,9 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) cnnOutputs.emplace_back(contextCNN(context.unsqueeze(1))); - auto totalInput = torch::cat(cnnOutputs, 1); + auto totalInput = cnnDropout(torch::cat(cnnOutputs, 1)); - return linear2(torch::relu(linear1(totalInput))); + return linear2(hiddenDropout(torch::relu(linear1(totalInput)))); } std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index 0ef9f8dd03b2be4e90afedea75fe5cedfb66ce41..37c206cf1baf43d2d3d230529e88f1e7271a48ca 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -6,7 +6,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config { std::stack<long> leftContext; for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < leftBorder; --index) - if (!config.isComment(index)) + if (!config.isCommentPredicted(index)) leftContext.push(index); std::vector<long> context; @@ -20,7 +20,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config } for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < leftBorder+rightBorder+1; ++index) - if (!config.isComment(index)) + if (!config.isCommentPredicted(index)) context.emplace_back(index); while (context.size() < leftBorder+rightBorder+1)