Skip to content
Snippets Groups Projects
Select Git revision
  • 39da3e9abc1e5d9da3dadeded4174dbbc63dfefc
  • main default protected
2 results

menu.py

Blame
  • CNN.hpp 419 B
    #ifndef CNN__H
    #define CNN__H
    
    #include <torch/torch.h>
    #include "MyModule.hpp"
    
    /**
     * @brief Implementation class of the CNN module (exposed to callers as `CNN`
     *        through the TORCH_MODULE macro that follows this class).
     *
     * Owns a list of torch::nn::Conv2d layers and a list of convolution window
     * sizes, which defaults to {2, 3}.
     * NOTE(review): presumably one Conv2d is built per entry of windowSizes —
     * confirm against the constructor in CNN.cpp.
     */
    class CNNImpl : public MyModule
    {
      public :

      /**
       * @brief Builds the module.
       * @param inputSize  Input dimension handed to the layers — exact meaning
       *                   is defined in CNN.cpp (TODO confirm).
       * @param outputSize Requested output dimension; presumably stored in the
       *                   member of the same name (TODO confirm in CNN.cpp).
       * @param options    Shared module options consumed by MyModule/CNN.cpp.
       */
      CNNImpl(int inputSize, int outputSize, ModuleOptions options);

      /**
       * @brief Applies the module to @p input and returns the resulting tensor.
       * NOTE(review): expected input shape is not visible here — check CNN.cpp.
       */
      torch::Tensor forward(torch::Tensor input);

      /**
       * @brief Reports the output size produced for a sequence of
       *        @p sequenceLength positions (computed in CNN.cpp).
       */
      int getOutputSize(int sequenceLength);

      private :

      // Convolution layers owned by this module.
      std::vector<torch::nn::Conv2d> CNNs;
      // Convolution window sizes; defaults to 2 and 3.
      std::vector<int> windowSizes = {2, 3};
      // Output dimension; set outside this header (presumably by the constructor).
      int outputSize;
    };
    // LibTorch macro: defines the module holder class `CNN` that wraps a CNNImpl
    // (the usual LibTorch pattern of a `<Name>Impl` class plus a `<Name>` holder).
    TORCH_MODULE(CNN);
    
    #endif