#include "LSTM.hpp"

/// Builds a (possibly stacked and/or bidirectional) LSTM layer.
///
/// @param inputSize  Number of features per time step fed to the LSTM.
/// @param outputSize Hidden size of each LSTM direction.
/// @param options    Tuple read as:
///                   <0> batch_first, <1> bidirectional, <2> num_layers,
///                   <3> dropout (all forwarded to torch::nn::LSTMOptions),
///                   <4> outputAll (whether forward() returns every time step
///                   or only the end time step(s)).
LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options)
  : outputAll(std::get<4>(options))
{
  auto lstmOptions = torch::nn::LSTMOptions(inputSize, outputSize)
    .batch_first(std::get<0>(options))
    .bidirectional(std::get<1>(options))
    .num_layers(std::get<2>(options))
    .dropout(std::get<3>(options));

  lstm = register_module("lstm", torch::nn::LSTM(lstmOptions));
}

/// Runs the LSTM and returns one flat feature vector per sample.
///
/// The result is always 2-D (batch, features), independent of the
/// batch_first setting, and its feature width equals
/// getOutputSize(sequenceLength).
///
/// @param input Sequence batch in the layout selected by batch_first.
/// @return (batch, features) tensor:
///   - outputAll: every time step's output, flattened.
///   - otherwise: the last time step's output; when bidirectional, the
///     first and last time steps' outputs concatenated.
torch::Tensor LSTMImpl::forward(torch::Tensor input)
{
  auto lstmOut = std::get<0>(lstm(input));

  // torch::nn::LSTM yields (seq, batch, dirs*hidden) unless batch_first is
  // set. Normalise to (batch, seq, dirs*hidden) so that dim 0 is always the
  // batch and dim 1 is always time below. Previously the non-batch-first
  // layout was sliced along the batch dimension by mistake.
  if (!lstm->options.batch_first())
    lstmOut = lstmOut.transpose(0, 1);

  // flatten() copes with the possibly non-contiguous transposed tensor and,
  // unlike reshape({size(0), -1}), does not fail on an empty batch.
  if (outputAll)
    return lstmOut.flatten(1);

  const auto lastStep = lstmOut.select(1, lstmOut.size(1) - 1);

  // Bidirectional: the backward direction finishes at t = 0, so the full
  // outputs at both ends of the sequence are concatenated. This yields
  // 2 * (2 * hidden) features, matching getOutputSize().
  if (lstm->options.bidirectional())
    return torch::cat({lstmOut.select(1, 0), lastStep}, 1);

  return lastStep;
}

/// Feature width of the tensor returned by forward().
///
/// @param sequenceLength Number of time steps in the input sequence
///                       (only relevant when outputAll is set).
/// @return Size of dimension 1 of forward()'s result.
int LSTMImpl::getOutputSize(int sequenceLength)
{
  const bool bidirectional = lstm->options.bidirectional();
  // Width of a single time step's output: hidden size per direction.
  const auto perStep = lstm->options.hidden_size() * (bidirectional ? 2 : 1);

  if (outputAll)
    return static_cast<int>(sequenceLength * perStep);

  // First and last time steps are concatenated when bidirectional.
  return static_cast<int>(bidirectional ? 2 * perStep : perStep);
}