// CNN.hpp (415 B)
// Author: Franck Dary
#ifndef CNN__H
#define CNN__H
#include <torch/torch.h>
/// Convolutional module built on torch::nn::Module.
///
/// Owns a list of torch::nn::Conv2d sub-modules. NOTE(review): presumably one
/// convolution per entry of `windowSizes` — confirm against the implementation
/// file. Layer construction and the forward computation are defined out of
/// line; this header only declares the interface.
class CNNImpl : public torch::nn::Module
{
  private :

  /// Window sizes handed to the constructor.
  std::vector<int> windowSizes;
  /// Convolution sub-modules. NOTE(review): presumably parallel to windowSizes.
  std::vector<torch::nn::Conv2d> CNNs;
  // Scalars are brace-initialized so they never hold indeterminate values,
  // even if a constructor path fails to set them explicitly.
  /// Filter count. NOTE(review): presumably output channels per convolution.
  int nbFilters{0};
  /// Element size. NOTE(review): presumably the width of one input element — confirm.
  int elementSize{0};

  public :

  /// Builds the module.
  /// @param windowSizes  window sizes (assumed: one convolution per entry — confirm).
  /// @param nbFilters    number of filters (assumed: per convolution — confirm).
  /// @param elementSize  size of one input element (assumed — confirm).
  CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize);

  /// Runs the module on `input` and returns the resulting tensor.
  /// NOTE(review): the expected input shape is not visible in this header —
  /// confirm against the implementation.
  torch::Tensor forward(torch::Tensor input);

  /// @return the output size of this module. NOTE(review): presumably the
  ///         length of the feature vector produced by forward() — confirm.
  std::size_t getOutputSize();
};
/// Declares the `CNN` module holder, a shared-pointer wrapper around CNNImpl.
TORCH_MODULE(CNN);
#endif