Skip to content
Snippets Groups Projects
Select Git revision
  • 0b09a4534c764866e71a009b552a1288e0c790b0
  • master default protected
  • v1.1
  • operations
  • v1.1.2
  • v1.1.1
  • v1.1.0
  • v1.0.4
  • v1.0.3
  • v1.0.2
  • v1.0.0
  • v1.0.1
12 results

test_madarray.py

Blame
  • CNN.hpp 419 B
    #ifndef CNN__H
    #define CNN__H
    
    #include <torch/torch.h>
    #include "MyModule.hpp"
    
    /// Convolutional module built on the project's `MyModule` base.
    ///
    /// Owns a list of 2-D convolution layers (`CNNs`) together with a list of
    /// window sizes (default {2, 3}). How the layers are built from the window
    /// sizes and how `forward` combines them is defined in the implementation
    /// file, not here — NOTE(review): confirm details in CNN.cpp.
    class CNNImpl : public MyModule
    {
      private :
    
      /// Convolution layers held by this module; presumably one per entry of
      /// `windowSizes` — TODO confirm against the constructor definition.
      std::vector<torch::nn::Conv2d> CNNs;
      /// Window sizes used for the convolutions; defaults to {2, 3}.
      std::vector<int> windowSizes{2, 3};
      /// Stored output size. Value-initialised so it can never be read
      /// indeterminate: the constructor parameter of the same name shadows this
      /// member, so an accidental `outputSize = outputSize;` in the constructor
      /// body would otherwise leave it uninitialised.
      int outputSize{0};
    
      public :
    
      /// Construct the module.
      /// @param inputSize  input feature size for the convolutions (assumed — confirm in CNN.cpp)
      /// @param outputSize presumably stored into the `outputSize` member — verify
      /// @param options    project-specific module options (see MyModule / ModuleOptions)
      CNNImpl(int inputSize, int outputSize, ModuleOptions options);
      /// Apply the module to `input` and return the resulting tensor.
      /// NOTE(review): expected input shape is not visible here — confirm in CNN.cpp.
      torch::Tensor forward(torch::Tensor input);
      /// Size of the output produced for a sequence of `sequenceLength` elements.
      /// NOTE(review): exact formula lives in CNN.cpp — confirm.
      int getOutputSize(int sequenceLength);
    };
    /// Defines `CNN`, the torch::nn::ModuleHolder wrapper (shared ownership)
    /// around `CNNImpl`, following the libtorch module-holder convention.
    TORCH_MODULE(CNN);
    
    #endif