#ifndef CNNNETWORK__H #define CNNNETWORK__H #include "NeuralNetwork.hpp" #include "CNN.hpp" class CNNNetworkImpl : public NeuralNetworkImpl { private : static constexpr int maxNbEmbeddings = 50000; int unknownValueThreshold; std::vector<int> focusedBufferIndexes; std::vector<int> focusedStackIndexes; std::vector<std::string> focusedColumns; std::vector<int> maxNbElements; int leftWindowRawInput; int rightWindowRawInput; int rawInputSize; torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Dropout embeddingsDropout{nullptr}; torch::nn::Dropout cnnDropout{nullptr}; torch::nn::Dropout hiddenDropout{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; CNN contextCNN{nullptr}; CNN rawInputCNN{nullptr}; std::vector<CNN> cnns; public : CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput); torch::Tensor forward(torch::Tensor input) override; std::vector<long> extractContext(Config & config, Dict & dict) const override; }; #endif