Commit 5e1f4062 authored by Franck Dary's avatar Franck Dary
Browse files

Fixed 'outputAll' option for LSTM module

parent 1473579c
......@@ -29,7 +29,7 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string &
.bidirectional(std::stoi(subModuleArguments[0]))
.num_layers(std::stoi(subModuleArguments[1]))
.dropout(std::stof(subModuleArguments[2]))
.complete(std::stoi(subModuleArguments[3]));
.complete(true);
for (auto & target : util::split(sm.str(8), ' '))
{
......
......@@ -13,7 +13,12 @@ LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options) : outpu
torch::Tensor LSTMImpl::forward(torch::Tensor input)
{
return std::get<0>(lstm(input));
auto res = std::get<0>(lstm(input));
if (outputAll)
return res;
return torch::cat({torch::narrow(res, 1, 0, 1), torch::narrow(res, 1, res.size(1)-1, 1)}, 1);
}
int LSTMImpl::getOutputSize(int sequenceLength)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment