#include "LSTM.hpp"

/// @brief Build the wrapped torch LSTM from the project's option tuple.
/// @param inputSize  Number of input features per time step.
/// @param outputSize Hidden size of the LSTM (passed as LSTMOptions' hidden_size).
/// @param options    Tuple read as: <0> batch_first, <1> bidirectional,
///                   <2> num_layers, <3> dropout, <4> outputAll (keep every
///                   time step in forward() instead of only the end steps).
LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options) : outputAll(std::get<4>(options))
{
  // Give each tuple slot a name before handing it to torch.
  const auto batchFirst = std::get<0>(options);
  const auto isBidirectional = std::get<1>(options);
  const auto layerCount = std::get<2>(options);
  const auto dropoutRate = std::get<3>(options);

  torch::nn::LSTMOptions lstmOptions(inputSize, outputSize);
  lstmOptions.batch_first(batchFirst);
  lstmOptions.bidirectional(isBidirectional);
  lstmOptions.num_layers(layerCount);
  lstmOptions.dropout(dropoutRate);

  lstm = register_module("lstm", torch::nn::LSTM(lstmOptions));
}

/// @brief Run the LSTM and reduce its output to a 2-D (batch, features) tensor.
///
/// The result depends on the module's settings:
/// - outputAll: every time step is kept and flattened to
///   (batch, seq * directions * hidden).
/// - bidirectional (and not outputAll): the outputs at the first and last time
///   steps are concatenated, giving (batch, 4 * hidden).
/// - otherwise: only the last time step is kept, giving (batch, hidden).
///
/// These sizes match getOutputSize().
/// @param input Sequence batch laid out according to the LSTM's batch_first
///              option: (batch, seq, features) if true, (seq, batch, features)
///              otherwise.
torch::Tensor LSTMImpl::forward(torch::Tensor input)
{
  auto lstmOut = std::get<0>(lstm(input));

  // Everything below indexes dim 0 as batch and dim 1 as time. The previous
  // code assumed this unconditionally, which selected/flattened the wrong axis
  // whenever batch_first was false. Normalise the layout first.
  if (!lstm->options.batch_first())
    lstmOut = lstmOut.transpose(0, 1);

  if (outputAll)
    return lstmOut.reshape({lstmOut.size(0), -1});

  // select(1, i) == narrow(1, i, 1).squeeze(1): one time step, time dim dropped.
  auto lastStep = lstmOut.select(1, lstmOut.size(1) - 1);

  if (lstm->options.bidirectional())
    return torch::cat({lstmOut.select(1, 0), lastStep}, 1);

  return lastStep;
}

/// @brief Number of features per batch element produced by forward().
/// @param sequenceLength Number of time steps in the input. It only matters
///                       when outputAll is set.
/// @return outputAll: sequenceLength * hidden * directions.
///         Otherwise: 4 * hidden when bidirectional (two concatenated
///         time steps, each 2 * hidden wide), else hidden.
int LSTMImpl::getOutputSize(int sequenceLength)
{
  const bool isBidirectional = lstm->options.bidirectional();
  const auto hiddenSize = lstm->options.hidden_size();

  if (!outputAll)
  {
    // forward() concatenates the first and last steps when bidirectional.
    return isBidirectional ? hiddenSize * 4 : hiddenSize;
  }

  const int directions = isBidirectional ? 2 : 1;
  return sequenceLength * hiddenSize * directions;
}