#ifndef NEURALNETWORK__H #define NEURALNETWORK__H #include <torch/torch.h> #include "Config.hpp" #include "Dict.hpp" class NeuralNetworkImpl : public torch::nn::Module { private : int leftBorder{5}; int rightBorder{5}; int nbStackElements{2}; protected : void setRightBorder(int rightBorder); void setLeftBorder(int leftBorder); void setNbStackElements(int nbStackElements); public : virtual std::vector<torch::Tensor> & denseParameters() = 0; virtual std::vector<torch::Tensor> & sparseParameters() = 0; virtual torch::Tensor forward(torch::Tensor input) = 0; std::vector<long> extractContext(Config & config, Dict & dict) const; int getContextSize() const; }; TORCH_MODULE(NeuralNetwork); #endif