From 2261c98b655864e2f475ce65e6cd2542f5d1aeef Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 15 Apr 2020 20:02:36 +0200 Subject: [PATCH] Fixed bug in RawInputLSTM where only last context was added to. Applying dropout after relu in MLP --- torch_modules/include/LSTMNetwork.hpp | 1 - torch_modules/src/MLP.cpp | 2 +- torch_modules/src/RawInputLSTM.cpp | 25 ++++++++++++++----------- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp index f0b58dc..7a6576c 100644 --- a/torch_modules/include/LSTMNetwork.hpp +++ b/torch_modules/include/LSTMNetwork.hpp @@ -15,7 +15,6 @@ class LSTMNetworkImpl : public NeuralNetworkImpl torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Dropout embeddingsDropout{nullptr}; - torch::nn::Dropout lstmDropout{nullptr}; MLP mlp{nullptr}; ContextLSTM contextLSTM{nullptr}; diff --git a/torch_modules/src/MLP.cpp b/torch_modules/src/MLP.cpp index 52a4c23..b886245 100644 --- a/torch_modules/src/MLP.cpp +++ b/torch_modules/src/MLP.cpp @@ -19,7 +19,7 @@ torch::Tensor MLPImpl::forward(torch::Tensor input) { torch::Tensor output = input; for (unsigned int i = 0; i < layers.size()-1; i++) - output = torch::relu(dropouts[i](layers[i](output))); + output = dropouts[i](torch::relu(layers[i](output))); return layers.back()(output); } diff --git a/torch_modules/src/RawInputLSTM.cpp b/torch_modules/src/RawInputLSTM.cpp index ebcfbfd..2aa8cfd 100644 --- a/torch_modules/src/RawInputLSTM.cpp +++ b/torch_modules/src/RawInputLSTM.cpp @@ -25,16 +25,19 @@ void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Di if (leftWindow < 0 or rightWindow < 0) return; - for (int i = 0; i < leftWindow; i++) - if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i)) - context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i)))); - else - context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); - - for (int i = 0; i <= rightWindow; i++) - if (config.hasCharacter(config.getCharacterIndex()+i)) - context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)))); - else - context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + for (auto & contextElement : context) + { + for (int i = 0; i < leftWindow; i++) + if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i)) + contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i)))); + else + contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + + for (int i = 0; i <= rightWindow; i++) + if (config.hasCharacter(config.getCharacterIndex()+i)) + contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)))); + else + contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + } } -- GitLab