Skip to content
Snippets Groups Projects
CNN.hpp 415 B
#ifndef CNN__H
#define CNN__H

#include <torch/torch.h>

/// Convolutional module built on a collection of torch::nn::Conv2d layers.
///
/// Only the interface is declared here; construction, the forward pass and the
/// output-size computation are implemented out of line.
/// NOTE(review): the expected input tensor shape and the layout of the output
/// are fixed by the implementation file, not by this header — confirm there
/// before relying on them.
class CNNImpl : public torch::nn::Module
{
  private :

  // Window sizes passed to the constructor.
  // Presumably one convolution is created per entry — TODO confirm in the .cpp.
  std::vector<int> windowSizes;
  // Convolution layers held by this module (libtorch module holders).
  // NOTE(review): they must be registered via register_module() to be seen by
  // parameters()/to(); presumably done in the constructor — confirm.
  std::vector<torch::nn::Conv2d> CNNs;
  // Number of filters per convolution, as passed to the constructor.
  // Brace-initialized so the value is never indeterminate, even if a
  // constructor path forgets to set it.
  int nbFilters{0};
  // Size of one input element, as passed to the constructor.
  // Brace-initialized for the same reason as nbFilters.
  int elementSize{0};

  public :

  /// Builds the module.
  /// @param windowSizes  window sizes to build convolutions for (taken by value; the caller's vector is not modified).
  /// @param nbFilters    number of filters.
  /// @param elementSize  size of one input element.
  CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize);

  /// Runs the module on @p input.
  /// @param input  input tensor; its required shape is defined by the implementation — TODO confirm.
  /// @return       the resulting tensor.
  torch::Tensor forward(torch::Tensor input);

  /// @return the size of the output produced by forward().
  /// NOTE(review): presumably derived from nbFilters and windowSizes — confirm in the implementation.
  std::size_t getOutputSize();

};
/// Declares `CNN`, the libtorch module holder (a shared_ptr-based wrapper) around CNNImpl.
TORCH_MODULE(CNN);

#endif