Newer
Older
#ifndef CNN__H
#define CNN__H
#include <torch/torch.h>
#include "fmt/core.h"
/// Torch module built around a collection of 2-D convolutions (`CNNs`).
///
/// Only declarations are visible here; construction and `forward` are
/// defined out of line.
/// NOTE(review): the member names suggest one Conv2d per entry of
/// `windowSizes`, each with `nbFilters` filters over inputs whose per-step
/// width is `elementSize` — confirm against the .cpp definitions.
class CNNImpl : public torch::nn::Module
{
  private :

  /// Window sizes, presumably copied from the constructor argument — TODO confirm.
  std::vector<long> windowSizes;
  /// Convolution submodules; presumably one per window size — TODO confirm.
  std::vector<torch::nn::Conv2d> CNNs;
  /// Number of filters. Zero-initialized in class so the value is never
  /// indeterminate, even if a constructor does not set it explicitly.
  int nbFilters{0};
  /// Element size. Zero-initialized in class for the same reason as nbFilters.
  int elementSize{0};

  public :

  /// Builds the module.
  /// @param windowSizes  window sizes for the convolutions (taken by value).
  /// @param nbFilters    number of filters.
  /// @param elementSize  size of one input element.
  /// NOTE(review): exact semantics of each argument are set in the .cpp.
  CNNImpl(std::vector<long> windowSizes, int nbFilters, int elementSize);

  /// Applies the module to `input` and returns the resulting tensor.
  /// NOTE(review): expected input/output shapes are not visible from this
  /// header — document them alongside the definition.
  torch::Tensor forward(torch::Tensor input);

  /// @return the module's output size; presumably the last-dimension size of
  ///         the tensor returned by forward() — TODO confirm.
  int getOutputSize();
};
/// Declares `CNN`, the torch::nn::ModuleHolder (shared-ownership wrapper)
/// around CNNImpl, following the libtorch module-holder convention.
TORCH_MODULE(CNN);
#endif