// GRU.hpp
// Author: Franck Dary
#ifndef GRU__H
#define GRU__H
#include <torch/torch.h>
#include "MyModule.hpp"
/// Recurrent layer that wraps a libtorch GRU (`torch::nn::GRU`) behind the
/// project's MyModule interface.
///
/// Only declarations live in this header; the constructor, forward() and
/// getOutputSize() are defined out of line.
class GRUImpl : public MyModule
{
  private :

  /// Underlying libtorch GRU. Starts as an empty holder (`nullptr`);
  /// presumably constructed/registered by the constructor — TODO confirm
  /// against the out-of-line definition.
  torch::nn::GRU gru{nullptr};

  /// Flag presumably consulted by forward()/getOutputSize() to choose between
  /// returning every time step's output or only a final one — verify against
  /// the implementation.
  /// Default-initialized so the object never holds an indeterminate value
  /// (reading one is undefined behavior) if a constructor path leaves it unset.
  bool outputAll{false};

  public :

  /// Builds the layer.
  /// @param inputSize   input feature size (presumably passed to the GRU's
  ///                    input size — TODO confirm)
  /// @param outputSize  output size of the layer (presumably the GRU hidden
  ///                    size — TODO confirm)
  /// @param options     project-declared ModuleOptions configuring the layer
  GRUImpl(int inputSize, int outputSize, ModuleOptions options);

  /// Runs the GRU over `input` and returns the resulting tensor.
  /// NOTE(review): the expected layout of `input` (e.g. batch/sequence/feature
  /// order) is not visible here — confirm in the implementation.
  torch::Tensor forward(torch::Tensor input);

  /// Returns the size of this layer's output for a sequence of
  /// `sequenceLength` steps. That the result depends on outputAll is an
  /// assumption — confirm in the implementation.
  int getOutputSize(int sequenceLength);
};

/// Generates `GRU`, the libtorch ModuleHolder wrapper around GRUImpl, so the
/// layer can be used like a built-in torch::nn module (shared ownership).
TORCH_MODULE(GRU);
#endif