From 5e1f4062dbafba142b52c3d68496b2d6bb21283f Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 30 Mar 2021 16:01:30 +0200
Subject: [PATCH] Fixed 'outputAll' option for LSTM module

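When 'outputAll' is set, LSTMImpl::forward now returns the hidden state of
every timestep; otherwise it keeps only the first and last timesteps,
concatenated along the sequence dimension. ContextualModule now always
requests the complete output ('complete(true)') instead of reading the
flag from subModuleArguments[3], presumably because it needs a
representation for every word of the context.

Illustrative output shapes (a sketch only, not taken from the code: the
batch-first layout is inferred from dimension 1 being narrowed as the
sequence dimension, and hiddenSize is assumed to already include the
bidirectional factor):

  input               -> [batch, seqLen, inputSize]
  outputAll == true   -> [batch, seqLen, hiddenSize]
  outputAll == false  -> [batch, 2, hiddenSize]
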
---
 torch_modules/src/ContextualModule.cpp | 2 +-
 torch_modules/src/LSTM.cpp             | 7 ++++++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index 164de44..c0a717c 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -29,7 +29,7 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string &
               .bidirectional(std::stoi(subModuleArguments[0]))
               .num_layers(std::stoi(subModuleArguments[1]))
               .dropout(std::stof(subModuleArguments[2]))
-              .complete(std::stoi(subModuleArguments[3]));
+              .complete(true);
 
             for (auto & target : util::split(sm.str(8), ' '))
             {
diff --git a/torch_modules/src/LSTM.cpp b/torch_modules/src/LSTM.cpp
index d84af46..c394321 100644
--- a/torch_modules/src/LSTM.cpp
+++ b/torch_modules/src/LSTM.cpp
@@ -13,7 +13,12 @@ LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options) : outpu
 
 torch::Tensor LSTMImpl::forward(torch::Tensor input)
 {
-  return std::get<0>(lstm(input));
+  auto res = std::get<0>(lstm(input));
+
+  if (outputAll)
+    return res;
+
+  return torch::cat({torch::narrow(res, 1, 0, 1), torch::narrow(res, 1, res.size(1)-1, 1)}, 1);
 }
 
 int LSTMImpl::getOutputSize(int sequenceLength)
-- 
GitLab