#ifndef CNNNETWORK__H #define CNNNETWORK__H #include "NeuralNetwork.hpp" class CNNNetworkImpl : public NeuralNetworkImpl { private : static inline std::vector<long> focusedBufferIndexes{0,1}; static inline std::vector<long> focusedStackIndexes{0,1}; static inline std::vector<long> windowSizes{2,3,4}; static constexpr unsigned int maxNbLetters = 10; torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; std::vector<torch::nn::Conv2d> CNNs; std::vector<torch::nn::Conv2d> lettersCNNs; public : CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements); torch::Tensor forward(torch::Tensor input) override; std::vector<long> extractContext(Config & config, Dict & dict) const override; }; #endif