#ifndef CNNNETWORK__H #define CNNNETWORK__H #include "NeuralNetwork.hpp" #include "CNN.hpp" class CNNNetworkImpl : public NeuralNetworkImpl { private : static constexpr unsigned int maxNbLetters = 10; private : std::vector<long> focusedBufferIndexes; std::vector<long> focusedStackIndexes; std::vector<std::string> focusedColumns; torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; CNN contextCNN{nullptr}; CNN lettersCNN{nullptr}; public : CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<long> focusedBufferIndexes, std::vector<long> focusedStackIndexes, std::vector<std::string> focusedColumns); torch::Tensor forward(torch::Tensor input) override; std::vector<long> extractContext(Config & config, Dict & dict) const override; }; #endif