#ifndef CNNNETWORK__H #define CNNNETWORK__H #include "NeuralNetwork.hpp" class CNNNetworkImpl : public NeuralNetworkImpl { private : static inline std::vector<long> windowSizes{2,3,4}; static constexpr unsigned int maxNbLetters = 10; private : std::vector<long> focusedBufferIndexes; std::vector<long> focusedStackIndexes; std::vector<std::string> focusedColumns; torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; std::vector<torch::nn::Conv2d> CNNs; std::vector<torch::nn::Conv2d> lettersCNNs; public : CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<long> focusedBufferIndexes, std::vector<long> focusedStackIndexes, std::vector<std::string> focusedColumns); torch::Tensor forward(torch::Tensor input) override; std::vector<long> extractContext(Config & config, Dict & dict) const override; }; #endif