Skip to content
Snippets Groups Projects
Commit f00fdc1f authored by Franck Dary
Browse files

Added dropouts to CNNNetwork

parent 983dc489
No related branches found
No related tags found
No related merge requests found
......@@ -18,6 +18,8 @@ class CNNNetworkImpl : public NeuralNetworkImpl
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Dropout embeddingsDropout{nullptr};
+ torch::nn::Dropout cnnDropout{nullptr};
+ torch::nn::Dropout hiddenDropout{nullptr};
torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr};
CNN contextCNN{nullptr};
......
......@@ -21,6 +21,8 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
+ cnnDropout = register_module("cnn_dropout", torch::nn::Dropout(0.3));
+ hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize));
int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize;
for (auto & col : focusedColumns)
......@@ -67,9 +69,9 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
cnnOutputs.emplace_back(contextCNN(context.unsqueeze(1)));
- auto totalInput = torch::cat(cnnOutputs, 1);
+ auto totalInput = cnnDropout(torch::cat(cnnOutputs, 1));
- return linear2(torch::relu(linear1(totalInput)));
+ return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
}
std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
......
......@@ -6,7 +6,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config
{
std::stack<long> leftContext;
for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < leftBorder; --index)
- if (!config.isComment(index))
+ if (!config.isCommentPredicted(index))
leftContext.push(index);
std::vector<long> context;
......@@ -20,7 +20,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config
}
for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < leftBorder+rightBorder+1; ++index)
- if (!config.isComment(index))
+ if (!config.isCommentPredicted(index))
context.emplace_back(index);
while (context.size() < leftBorder+rightBorder+1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment