Skip to content
Snippets Groups Projects
LSTM.cpp 1.02 KiB
Newer Older
  • Learn to ignore specific revisions
  • Franck Dary's avatar
    Franck Dary committed
    #include "LSTM.hpp"
    
    
    /// Builds the wrapped torch LSTM and registers it as the "lstm" submodule.
    ///
    /// @param inputSize  Feature size of each input element (LSTM input_size).
    /// @param outputSize Hidden state size of the LSTM (LSTM hidden_size).
    /// @param options    Positional settings, read by index:
    ///                   0 = batch_first, 1 = bidirectional, 2 = num_layers,
    ///                   3 = dropout, 4 = outputAll (whether forward() returns
    ///                   every timestep instead of a summary).
    LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options) : outputAll(std::get<4>(options))
    {
      // Unpack the positional fields once so each setter call is self-describing.
      const auto batchFirst = std::get<0>(options);
      const auto bidirectional = std::get<1>(options);
      const auto layerCount = std::get<2>(options);
      const auto dropoutRate = std::get<3>(options);

      torch::nn::LSTMOptions config(inputSize, outputSize);
      config.batch_first(batchFirst);
      config.bidirectional(bidirectional);
      config.num_layers(layerCount);
      config.dropout(dropoutRate);

      lstm = register_module("lstm", torch::nn::LSTM(config));
    }
    
    /// Runs the LSTM over `input` and returns a 2-D tensor whose first
    /// dimension is the first dimension of the LSTM output.
    ///
    /// - outputAll: every timestep is kept, flattened into one row per entry
    ///   of dimension 0.
    /// - bidirectional (and not outputAll): the outputs at the first and last
    ///   positions of dimension 1 are concatenated along the feature axis.
    /// - otherwise: only the output at the last position of dimension 1.
    ///
    /// NOTE(review): dimension 1 is treated as the sequence axis, which matches
    /// the batch_first layout -- confirm callers never disable batch_first.
    torch::Tensor LSTMImpl::forward(torch::Tensor input)
    {
      // The wrapped module returns (output, state); only the output is used.
      const torch::Tensor sequenceOut = std::get<0>(lstm(input));

      if (outputAll)
        return sequenceOut.reshape({sequenceOut.size(0), -1});

      // Extracts the slice at `index` along dimension 1 and drops that dimension.
      const auto stepAt = [&sequenceOut](int64_t index)
      {
        return sequenceOut.narrow(1, index, 1).squeeze(1);
      };

      const auto lastIndex = sequenceOut.size(1) - 1;

      if (!lstm->options.bidirectional())
        return stepAt(lastIndex);

      return torch::cat({stepAt(0), stepAt(lastIndex)}, 1);
    }
    
    /// Returns the width of the feature dimension produced by forward().
    ///
    /// @param sequenceLength Number of timesteps fed to the LSTM; only used
    ///                       when outputAll is set.
    /// @return With outputAll: sequenceLength * hidden_size * (2 if
    ///         bidirectional, else 1). Otherwise: hidden_size * (4 if
    ///         bidirectional, else 1).
    int LSTMImpl::getOutputSize(int sequenceLength)
    {
      const auto hiddenSize = lstm->options.hidden_size();
      const bool twoWay = lstm->options.bidirectional();

      if (outputAll)
        return sequenceLength * hiddenSize * (twoWay ? 2 : 1);

      // The factor is 4 in the bidirectional case because forward() concatenates
      // two timesteps (first and last), each carrying both directions.
      return hiddenSize * (twoWay ? 4 : 1);
    }