diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp index f0b58dc099d512e73ce95c9d9c808cf206006cc5..7a6576cc58c586481ac4e4444e73a225934be78f 100644 --- a/torch_modules/include/LSTMNetwork.hpp +++ b/torch_modules/include/LSTMNetwork.hpp @@ -15,7 +15,6 @@ class LSTMNetworkImpl : public NeuralNetworkImpl torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Dropout embeddingsDropout{nullptr}; - torch::nn::Dropout lstmDropout{nullptr}; MLP mlp{nullptr}; ContextLSTM contextLSTM{nullptr}; diff --git a/torch_modules/src/MLP.cpp b/torch_modules/src/MLP.cpp index 52a4c235d5489165741ab421995ce9055401e67b..b886245913d33d6ed5e138fddc57d2ca22a6207e 100644 --- a/torch_modules/src/MLP.cpp +++ b/torch_modules/src/MLP.cpp @@ -19,7 +19,7 @@ torch::Tensor MLPImpl::forward(torch::Tensor input) { torch::Tensor output = input; for (unsigned int i = 0; i < layers.size()-1; i++) - output = torch::relu(dropouts[i](layers[i](output))); + output = dropouts[i](torch::relu(layers[i](output))); return layers.back()(output); } diff --git a/torch_modules/src/RawInputLSTM.cpp b/torch_modules/src/RawInputLSTM.cpp index ebcfbfd9e6a6077eb280ff3d7e84c4fc7129fc8d..2aa8cfd9c2f97c9fd99d6ed08d3e2a0aa75c35d8 100644 --- a/torch_modules/src/RawInputLSTM.cpp +++ b/torch_modules/src/RawInputLSTM.cpp @@ -25,16 +25,19 @@ void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Di if (leftWindow < 0 or rightWindow < 0) return; - for (int i = 0; i < leftWindow; i++) - if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i)) - context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i)))); - else - context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); - - for (int i = 0; i <= rightWindow; i++) - if (config.hasCharacter(config.getCharacterIndex()+i)) - context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)))); - else - context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + for (auto & contextElement : context) + { + for (int i = 0; i < leftWindow; i++) + if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i)) + contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i)))); + else + contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + + for (int i = 0; i <= rightWindow; i++) + if (config.hasCharacter(config.getCharacterIndex()+i)) + contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)))); + else + contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + } }