Select Git revision
SplitTransModule.cpp
-
Franck Dary authored
CNN.hpp 419 B
#ifndef CNN__H
#define CNN__H
#include <torch/torch.h>
#include "MyModule.hpp"
/// Convolutional sub-module deriving from the project's MyModule.
///
/// Only the interface is declared here; the behaviour lives in the matching
/// source file. Notes marked "presumably" are inferred from names and should
/// be confirmed against that implementation.
class CNNImpl : public MyModule
{
  private :

  /// Convolution layers owned by this module.
  /// Presumably one layer per entry of `windowSizes` — TODO confirm in the .cpp.
  std::vector<torch::nn::Conv2d> CNNs;
  /// Convolution window sizes; defaults to {2, 3}.
  std::vector<int> windowSizes{2, 3};
  /// Output size of each convolution, presumably taken from the constructor's
  /// `outputSize` argument — confirm. Brace-initialized so the member never
  /// holds an indeterminate value if a constructor path leaves it unset.
  int outputSize{0};

  public :

  /// Builds the module.
  /// @param inputSize  size of the incoming features (exact meaning defined in the .cpp — confirm)
  /// @param outputSize size produced by the convolutions (see member of the same name)
  /// @param options    project-specific ModuleOptions controlling construction
  CNNImpl(int inputSize, int outputSize, ModuleOptions options);

  /// Applies the module to `input` and returns the resulting tensor.
  /// NOTE(review): the expected input shape/rank is not visible from this header — confirm.
  torch::Tensor forward(torch::Tensor input);

  /// Returns the size of the module's output for an input sequence of
  /// `sequenceLength` elements (presumably the flattened feature count — confirm).
  int getOutputSize(int sequenceLength);
};
// libtorch macro: declares `CNN`, the ModuleHolder wrapper type around CNNImpl.
TORCH_MODULE(CNN);
#endif