#ifndef NEURALNETWORK__H #define NEURALNETWORK__H #include <torch/torch.h> #include "Config.hpp" #include "Dict.hpp" class NeuralNetworkImpl : public torch::nn::Module { public : static torch::Device device; private : bool splitUnknown{false}; std::string state; protected : static constexpr int maxNbEmbeddings = 150000; public : virtual torch::Tensor forward(torch::Tensor input) = 0; virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0; bool mustSplitUnknown() const; void setSplitUnknown(bool splitUnknown); void setState(const std::string & state); const std::string & getState() const; }; TORCH_MODULE(NeuralNetwork); #endif