Newer
Older
/// Builds the wrapped torch::nn::LSTM and registers it as submodule "lstm".
///
/// @param inputSize  number of input features per timestep.
/// @param outputSize hidden size of the LSTM (per direction).
/// @param options    tuple read as: <0> batch_first, <1> bidirectional,
///                   <3> dropout, <4> outputAll (stored in the member and
///                   used by forward()/getOutputSize()).
///
/// NOTE(review): std::get<2>(options) is not read here. The LSTM therefore
/// keeps libtorch's default num_layers (1), and with a single layer the
/// dropout option has no effect. Confirm whether slot <2> was meant to be
/// forwarded as num_layers.
LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options)
  : outputAll(std::get<4>(options))
{
  const auto batchFirst = std::get<0>(options);
  const auto bidirectional = std::get<1>(options);
  const auto dropout = std::get<3>(options);

  torch::nn::LSTMOptions lstmOptions(inputSize, outputSize);
  lstmOptions.batch_first(batchFirst);
  lstmOptions.bidirectional(bidirectional);
  lstmOptions.dropout(dropout);

  lstm = register_module("lstm", torch::nn::LSTM(lstmOptions));
}
/// Runs the wrapped LSTM over `input` and reduces its per-timestep output to
/// a 2-D tensor whose width matches getOutputSize().
///
/// The indexing treats dimension 0 as the batch and dimension 1 as the
/// sequence. That layout only holds when the LSTM was built with
/// batch_first = true (NOTE(review): controlled by std::get<0>(options) in the
/// constructor; confirm callers always enable it).
///
/// - outputAll:     every timestep is kept and flattened per batch item.
/// - bidirectional: the first and last timesteps are concatenated along the
///                  feature dimension. The backward direction finishes at
///                  step 0 and the forward direction at the last step.
/// - otherwise:     only the last timestep is returned.
torch::Tensor LSTMImpl::forward(torch::Tensor input)
{
  // LSTM::forward returns (output, (h_n, c_n)); only the per-timestep output
  // sequence is used below.
  auto lstmOut = std::get<0>(lstm(input));

  if (outputAll)
    return lstmOut.reshape({lstmOut.size(0), -1});

  // select() drops the sequence dimension, same as narrow(...).squeeze(1).
  auto lastStep = lstmOut.select(1, lstmOut.size(1) - 1);

  if (lstm->options.bidirectional())
    return torch::cat({lstmOut.select(1, 0), lastStep}, 1);

  return lastStep;
}
/// Number of features per batch item that forward() produces.
///
/// @param sequenceLength number of timesteps in the input. It only matters
///        when outputAll is set, because then every timestep is kept.
/// @return with outputAll: sequenceLength * hidden_size * directions.
///         Without outputAll: 4 * hidden_size when bidirectional (first and
///         last timestep, each 2 * hidden_size wide), else hidden_size.
int LSTMImpl::getOutputSize(int sequenceLength)
{
  const auto isBidirectional = lstm->options.bidirectional();
  const auto hiddenSize = lstm->options.hidden_size();

  if (!outputAll)
    return isBidirectional ? hiddenSize * 4 : hiddenSize;

  const int directions = isBidirectional ? 2 : 1;
  return sequenceLength * hiddenSize * directions;
}