diff --git a/torch_modules/include/ConcatWordsNetwork.hpp b/torch_modules/include/ConcatWordsNetwork.hpp
index 064a00eda70334f9e104db793b6476e94f43e36c..7152eba2e06aa74b909b2136514d64120b6c0b8d 100644
--- a/torch_modules/include/ConcatWordsNetwork.hpp
+++ b/torch_modules/include/ConcatWordsNetwork.hpp
@@ -10,6 +10,7 @@ class ConcatWordsNetworkImpl : public NeuralNetworkImpl
   torch::nn::Embedding wordEmbeddings{nullptr};
   torch::nn::Linear linear1{nullptr};
   torch::nn::Linear linear2{nullptr};
+  torch::nn::Dropout dropout{nullptr};
 
   public :
 
diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp
index b72694c4fa296a2215f8c3d65991977a2e69e6dc..fd9f2b8537a6685bbca24e2d664f570160c4dfe0 100644
--- a/torch_modules/src/ConcatWordsNetwork.cpp
+++ b/torch_modules/src/ConcatWordsNetwork.cpp
@@ -10,12 +10,13 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, in
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
   linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500));
   linear2 = register_module("linear2", torch::nn::Linear(500, nbOutputs));
+  dropout = register_module("dropout", torch::nn::Dropout(0.3));
 }
 
 torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
 {
   // input dim = {batch, sequence, embeddings}
-  auto wordsAsEmb = wordEmbeddings(input);
+  auto wordsAsEmb = dropout(wordEmbeddings(input));
   // reshaped dim = {batch, sequence of embeddings}
   auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)});
 
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index a63c977799d1b2b334baca1153e1f4dd2bc9bf40..177f5f28e4feb0ecdeafe3a64bfffa0b11a63c5b 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -18,7 +18,7 @@ class Trainer
   DataLoader dataLoader{nullptr};
   std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
-  int batchSize{100};
+  int batchSize{50};
   int nbExamples{0};
 
   public :
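
Note (not part of the patch): torch::nn::Dropout only zeroes activations while the module is in training mode; in evaluation mode it is the identity, so the 0.3 dropout applied to the embeddings above will not perturb predictions as long as the caller toggles modes correctly. A minimal standalone sketch of that behavior; the variable names are illustrative, not from this repository:

#include <torch/torch.h>
#include <iostream>

int main()
{
  torch::nn::Dropout dropout(0.3);
  auto x = torch::ones({2, 4});

  dropout->train();          // training mode: ~30% of activations are zeroed,
  auto noisy = dropout(x);   // survivors are scaled by 1/(1 - 0.3)

  dropout->eval();           // evaluation mode: dropout is the identity
  auto clean = dropout(x);   // clean is exactly x

  std::cout << noisy << '\n' << clean << '\n';
}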