Skip to content
Snippets Groups Projects
CNN.hpp 431 B
Newer Older
Franck Dary's avatar
Franck Dary committed
#ifndef CNN__H
#define CNN__H

#include <torch/torch.h>
#include "fmt/core.h"

/// @brief LibTorch module that owns a collection of 2-D convolutions.
///
/// NOTE(review): presumably one Conv2d per entry of `windowSizes`; the
/// construction logic lives in the implementation file — confirm there.
/// Client code normally uses the `CNN` holder generated by TORCH_MODULE below
/// rather than CNNImpl directly.
class CNNImpl : public torch::nn::Module
{
  private :

  /// Window sizes; presumably a copy of the constructor argument (confirm).
  std::vector<long> windowSizes;
  /// Convolution submodules (presumably one per window size — confirm).
  std::vector<torch::nn::Conv2d> CNNs;
  /// Filter count. Default-initialized so it is never indeterminate, even if
  /// a constructor's initializer list omits it.
  int nbFilters{0};
  /// Element size. Default-initialized so it is never indeterminate, even if
  /// a constructor's initializer list omits it.
  int elementSize{0};

  public :

  /// @brief Builds the module.
  /// @param windowSizes  convolution window sizes
  ///                     (NOTE(review): presumably kernel heights — confirm).
  /// @param nbFilters    number of filters
  ///                     (presumably output channels per convolution — confirm).
  /// @param elementSize  size of one input element
  ///                     (presumably the kernel width — confirm).
  CNNImpl(std::vector<long> windowSizes, int nbFilters, int elementSize);

  /// @brief Applies the module to `input` and returns the resulting tensor.
  /// NOTE(review): the expected input shape/dtype is not visible in this
  /// header; document it once confirmed against the definition.
  torch::Tensor forward(torch::Tensor input);

  /// @brief Returns the module's output size.
  /// NOTE(review): presumably the last-dimension size of forward()'s result,
  /// e.g. nbFilters * windowSizes.size() — confirm against the definition.
  int getOutputSize();

};
/// Generates `CNN`, the torch::nn::ModuleHolder wrapper around CNNImpl
/// (shared-ownership handle, like LibTorch's built-in modules).
TORCH_MODULE(CNN);

#endif