Skip to content
Snippets Groups Projects
Select Git revision
  • d2655230d8b685911d3e3b7ffe0dceab409a351b
  • master default protected
  • loss
  • producer
4 results

GRU.hpp

Blame
  • GRU.hpp 373 B
    #ifndef GRU__H
    #define GRU__H
    
    #include <torch/torch.h>
    #include "MyModule.hpp"
    
    /// GRU layer exposed as a project module (derives from MyModule).
    ///
    /// This is the implementation class behind the `GRU` module holder
    /// generated by TORCH_MODULE(GRU). Only declarations appear in this
    /// header; the member function definitions live elsewhere.
    class GRUImpl : public MyModule
    {
      private :

      /// Underlying libtorch GRU. Starts out as an empty holder (nullptr).
      /// NOTE(review): presumably constructed and registered in the GRUImpl
      /// constructor — confirm in the implementation file.
      torch::nn::GRU gru{nullptr};

      /// Output-selection flag. Given a default member initializer so it can
      /// never be read with an indeterminate value, even if a constructor path
      /// forgets to set it.
      /// NOTE(review): presumably selects whether forward() returns the output
      /// of every timestep or only part of it — confirm.
      bool outputAll{false};

      public :

      /// Builds the module.
      /// @param inputSize  input size for the GRU (assumed per-timestep feature
      ///                   count — TODO confirm)
      /// @param outputSize output size for the GRU (assumed hidden size —
      ///                   TODO confirm)
      /// @param options    project-specific ModuleOptions (defined outside this
      ///                   header)
      GRUImpl(int inputSize, int outputSize, ModuleOptions options);

      /// Runs @p input through the module and returns the resulting tensor.
      /// NOTE(review): the expected input layout (batch-first or not, rank)
      /// is not visible here — confirm against the implementation.
      torch::Tensor forward(torch::Tensor input);

      /// Returns the size of this module's output for a sequence of
      /// @p sequenceLength steps.
      /// NOTE(review): presumably depends on outputAll — confirm.
      int getOutputSize(int sequenceLength);
    };
    // libtorch convention: generates the `GRU` module-holder type that wraps
    // GRUImpl, so callers use `GRU` rather than `GRUImpl` directly.
    TORCH_MODULE(GRU);
    
    #endif