From 5e1f4062dbafba142b52c3d68496b2d6bb21283f Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 30 Mar 2021 16:01:30 +0200 Subject: [PATCH] Fixed 'outputAll' option for LSTM module --- torch_modules/src/ContextualModule.cpp | 2 +- torch_modules/src/LSTM.cpp | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp index 164de44..c0a717c 100644 --- a/torch_modules/src/ContextualModule.cpp +++ b/torch_modules/src/ContextualModule.cpp @@ -29,7 +29,7 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string & .bidirectional(std::stoi(subModuleArguments[0])) .num_layers(std::stoi(subModuleArguments[1])) .dropout(std::stof(subModuleArguments[2])) - .complete(std::stoi(subModuleArguments[3])); + .complete(true); for (auto & target : util::split(sm.str(8), ' ')) { diff --git a/torch_modules/src/LSTM.cpp b/torch_modules/src/LSTM.cpp index d84af46..c394321 100644 --- a/torch_modules/src/LSTM.cpp +++ b/torch_modules/src/LSTM.cpp @@ -13,7 +13,12 @@ LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options) : outpu torch::Tensor LSTMImpl::forward(torch::Tensor input) { - return std::get<0>(lstm(input)); + auto res = std::get<0>(lstm(input)); + + if (outputAll) + return res; + + return torch::cat({torch::narrow(res, 1, 0, 1), torch::narrow(res, 1, res.size(1)-1, 1)}, 1); } int LSTMImpl::getOutputSize(int sequenceLength) -- GitLab