Skip to content
Snippets Groups Projects
Commit 2261c98b authored by Franck Dary's avatar Franck Dary
Browse files

Fixed bug in RawInputLSTM where only last context was added to. Applying dropout after relu in MLP

parent f6de0f30
No related branches found
No related tags found
No related merge requests found
......@@ -15,7 +15,6 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Dropout embeddingsDropout{nullptr};
torch::nn::Dropout lstmDropout{nullptr};
MLP mlp{nullptr};
ContextLSTM contextLSTM{nullptr};
......
......@@ -19,7 +19,7 @@ torch::Tensor MLPImpl::forward(torch::Tensor input)
{
torch::Tensor output = input;
for (unsigned int i = 0; i < layers.size()-1; i++)
output = torch::relu(dropouts[i](layers[i](output)));
output = dropouts[i](torch::relu(layers[i](output)));
return layers.back()(output);
}
......
......@@ -25,16 +25,19 @@ void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Di
if (leftWindow < 0 or rightWindow < 0)
return;
for (int i = 0; i < leftWindow; i++)
if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i))));
else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
for (int i = 0; i <= rightWindow; i++)
if (config.hasCharacter(config.getCharacterIndex()+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
else
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
for (auto & contextElement : context)
{
for (int i = 0; i < leftWindow; i++)
if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i))
contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i))));
else
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
for (int i = 0; i <= rightWindow; i++)
if (config.hasCharacter(config.getCharacterIndex()+i))
contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
else
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment