Skip to content
Snippets Groups Projects
Commit f00fdc1f authored by Franck Dary
Browse files

Added dropouts to CNNNetwork

parent 983dc489
No related branches found
No related tags found
No related merge requests found
...@@ -18,6 +18,8 @@ class CNNNetworkImpl : public NeuralNetworkImpl ...@@ -18,6 +18,8 @@ class CNNNetworkImpl : public NeuralNetworkImpl
torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Dropout embeddingsDropout{nullptr}; torch::nn::Dropout embeddingsDropout{nullptr};
torch::nn::Dropout cnnDropout{nullptr};
torch::nn::Dropout hiddenDropout{nullptr};
torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr}; torch::nn::Linear linear2{nullptr};
CNN contextCNN{nullptr}; CNN contextCNN{nullptr};
......
...@@ -21,6 +21,8 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i ...@@ -21,6 +21,8 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3)); embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
cnnDropout = register_module("cnn_dropout", torch::nn::Dropout(0.3));
hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize)); contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize));
int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize; int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize;
for (auto & col : focusedColumns) for (auto & col : focusedColumns)
...@@ -67,9 +69,9 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) ...@@ -67,9 +69,9 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
cnnOutputs.emplace_back(contextCNN(context.unsqueeze(1))); cnnOutputs.emplace_back(contextCNN(context.unsqueeze(1)));
auto totalInput = torch::cat(cnnOutputs, 1); auto totalInput = cnnDropout(torch::cat(cnnOutputs, 1));
return linear2(torch::relu(linear1(totalInput))); return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
} }
std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
......
...@@ -6,7 +6,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config ...@@ -6,7 +6,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config
{ {
std::stack<long> leftContext; std::stack<long> leftContext;
for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < leftBorder; --index) for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < leftBorder; --index)
if (!config.isComment(index)) if (!config.isCommentPredicted(index))
leftContext.push(index); leftContext.push(index);
std::vector<long> context; std::vector<long> context;
...@@ -20,7 +20,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config ...@@ -20,7 +20,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config
} }
for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < leftBorder+rightBorder+1; ++index) for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < leftBorder+rightBorder+1; ++index)
if (!config.isComment(index)) if (!config.isCommentPredicted(index))
context.emplace_back(index); context.emplace_back(index);
while (context.size() < leftBorder+rightBorder+1) while (context.size() < leftBorder+rightBorder+1)
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment