Skip to content
Snippets Groups Projects
Commit 5e1f4062 authored by Franck Dary's avatar Franck Dary
Browse files

Fixed 'outputAll' option for LSTM module

parent 1473579c
No related branches found
No related tags found
No related merge requests found
...@@ -29,7 +29,7 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string & ...@@ -29,7 +29,7 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string &
.bidirectional(std::stoi(subModuleArguments[0])) .bidirectional(std::stoi(subModuleArguments[0]))
.num_layers(std::stoi(subModuleArguments[1])) .num_layers(std::stoi(subModuleArguments[1]))
.dropout(std::stof(subModuleArguments[2])) .dropout(std::stof(subModuleArguments[2]))
.complete(std::stoi(subModuleArguments[3])); .complete(true);
for (auto & target : util::split(sm.str(8), ' ')) for (auto & target : util::split(sm.str(8), ' '))
{ {
......
...@@ -13,7 +13,12 @@ LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options) : outpu ...@@ -13,7 +13,12 @@ LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options) : outpu
torch::Tensor LSTMImpl::forward(torch::Tensor input) torch::Tensor LSTMImpl::forward(torch::Tensor input)
{ {
return std::get<0>(lstm(input)); auto res = std::get<0>(lstm(input));
if (outputAll)
return res;
return torch::cat({torch::narrow(res, 1, 0, 1), torch::narrow(res, 1, res.size(1)-1, 1)}, 1);
} }
int LSTMImpl::getOutputSize(int sequenceLength) int LSTMImpl::getOutputSize(int sequenceLength)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment