Skip to content
Snippets Groups Projects
Commit 2261c98b authored by Franck Dary's avatar Franck Dary
Browse files

Fixed bug in RawInputLSTM where only last context was added to. Applying dropout after relu in MLP

parent f6de0f30
No related branches found
No related tags found
No related merge requests found
...@@ -15,7 +15,6 @@ class LSTMNetworkImpl : public NeuralNetworkImpl ...@@ -15,7 +15,6 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Dropout embeddingsDropout{nullptr}; torch::nn::Dropout embeddingsDropout{nullptr};
torch::nn::Dropout lstmDropout{nullptr};
MLP mlp{nullptr}; MLP mlp{nullptr};
ContextLSTM contextLSTM{nullptr}; ContextLSTM contextLSTM{nullptr};
......
...@@ -19,7 +19,7 @@ torch::Tensor MLPImpl::forward(torch::Tensor input) ...@@ -19,7 +19,7 @@ torch::Tensor MLPImpl::forward(torch::Tensor input)
{ {
torch::Tensor output = input; torch::Tensor output = input;
for (unsigned int i = 0; i < layers.size()-1; i++) for (unsigned int i = 0; i < layers.size()-1; i++)
output = torch::relu(dropouts[i](layers[i](output))); output = dropouts[i](torch::relu(layers[i](output)));
return layers.back()(output); return layers.back()(output);
} }
......
...@@ -25,16 +25,19 @@ void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Di ...@@ -25,16 +25,19 @@ void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Di
if (leftWindow < 0 or rightWindow < 0) if (leftWindow < 0 or rightWindow < 0)
return; return;
for (int i = 0; i < leftWindow; i++) for (auto & contextElement : context)
if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i)) {
context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i)))); for (int i = 0; i < leftWindow; i++)
else if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i))
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i))));
else
for (int i = 0; i <= rightWindow; i++) contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
if (config.hasCharacter(config.getCharacterIndex()+i))
context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)))); for (int i = 0; i <= rightWindow; i++)
else if (config.hasCharacter(config.getCharacterIndex()+i))
context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
else
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment