Commit 4bebd4a8 authored by Franck Dary

Added dropout layer to ConcatWordsNetwork

parent fb4e0869
@@ -10,6 +10,7 @@ class ConcatWordsNetworkImpl : public NeuralNetworkImpl
   torch::nn::Embedding wordEmbeddings{nullptr};
   torch::nn::Linear linear1{nullptr};
   torch::nn::Linear linear2{nullptr};
+  torch::nn::Dropout dropout{nullptr};

   public :
...
@@ -10,12 +10,13 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, in
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
   linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500));
   linear2 = register_module("linear2", torch::nn::Linear(500, nbOutputs));
+  dropout = register_module("dropout", torch::nn::Dropout(0.3));
 }

 torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
 {
   // input dim = {batch, sequence, embeddings}
-  auto wordsAsEmb = wordEmbeddings(input);
+  auto wordsAsEmb = dropout(wordEmbeddings(input));
   // reshaped dim = {batch, sequence of embeddings}
   auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)});
...
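The practical effect of routing the embeddings through dropout depends on the module's mode: LibTorch applies dropout only while the network is in training mode and makes it the identity under eval(). A minimal standalone sketch of that behavior, assuming stock LibTorch (the tensor shape and the printing are illustrative, not from this repository):

#include <torch/torch.h>
#include <iostream>

int main()
{
  // Same probability as the module registered in this commit.
  torch::nn::Dropout dropout(0.3);

  auto x = torch::ones({2, 5});

  dropout->train();                      // training mode: each value is zeroed with
  std::cout << dropout(x) << std::endl;  // p = 0.3, survivors scaled by 1/(1-0.3)

  dropout->eval();                       // evaluation mode: dropout is a no-op,
  std::cout << dropout(x) << std::endl;  // so this prints x unchanged
}

Because the layer is attached with register_module, calling train() or eval() on the enclosing ConcatWordsNetworkImpl propagates the mode to dropout automatically.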
@@ -18,7 +18,7 @@ class Trainer
   DataLoader dataLoader{nullptr};
   std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
-  int batchSize{100};
+  int batchSize{50};
   int nbExamples{0};

   public :
...
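For context on the optimizer member visible in this hunk: in LibTorch an Adam optimizer is built over a module's parameter tensors. A hypothetical sketch of how the Trainer might initialize it (the helper name and learning rate are assumptions, not taken from this commit):

#include <torch/torch.h>
#include <memory>

// Hypothetical helper, not from this repository: builds the Adam optimizer
// that Trainer stores in its std::unique_ptr<torch::optim::Adam> member.
std::unique_ptr<torch::optim::Adam> makeOptimizer(torch::nn::Module & network)
{
  // network.parameters() collects every tensor registered on the module,
  // i.e. the embedding table and the weights/biases of linear1 and linear2
  // (dropout itself has no trainable parameters).
  return std::make_unique<torch::optim::Adam>(
      network.parameters(), torch::optim::AdamOptions(2e-3));
}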