Skip to content
Snippets Groups Projects
GRU.hpp 373 B
#ifndef GRU__H
#define GRU__H

#include <torch/torch.h>
#include "MyModule.hpp"

/// @brief Project module wrapping a libtorch `torch::nn::GRU` layer behind the
/// `MyModule` interface.
///
/// Only declarations are visible in this header; the member descriptions below
/// marked "TODO confirm" are inferred from names and should be checked against
/// the implementation file.
class GRUImpl : public MyModule
{
  private :

  /// Underlying recurrent layer. Starts as an empty (null) module holder;
  /// presumably constructed in the GRUImpl constructor — TODO confirm.
  torch::nn::GRU gru{nullptr};
  /// Presumably selects whether forward() returns the outputs of every time
  /// step or only a reduced subset — TODO confirm. Default-initialized so the
  /// flag can never be read uninitialized if a constructor path does not set it
  /// (a constructor mem-initializer still overrides this default).
  bool outputAll{false};

  public :

  /// @param inputSize  Input feature size given to the GRU (TODO confirm units).
  /// @param outputSize Output/hidden size of the GRU (TODO confirm).
  /// @param options    Project-specific module options.
  GRUImpl(int inputSize, int outputSize, ModuleOptions options);
  /// Forward pass; presumably feeds @p input through `gru` — TODO confirm the
  /// expected input shape (batch/sequence/feature layout).
  /// @return The output tensor of this module.
  torch::Tensor forward(torch::Tensor input);
  /// @param sequenceLength Number of time steps in the input sequence.
  /// @return Size of this module's output for that sequence length
  ///         (presumably depends on `outputAll` — TODO confirm).
  int getOutputSize(int sequenceLength);
};
/// Generates `GRU`, the libtorch module-holder type wrapping GRUImpl.
TORCH_MODULE(GRU);

#endif