#ifndef NEURALNETWORK__H #define NEURALNETWORK__H #include <torch/torch.h> #include <filesystem> #include "Config.hpp" #include "NameHolder.hpp" #include "StateHolder.hpp" class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public StateHolder { public : static torch::Device device; private : std::string state; public : virtual torch::Tensor forward(torch::Tensor input) = 0; virtual std::vector<std::vector<long>> extractContext(Config & config) = 0; virtual void registerEmbeddings(std::filesystem::path pretrained) = 0; virtual void saveDicts(std::filesystem::path path) = 0; virtual void loadDicts(std::filesystem::path path) = 0; virtual void setDictsState(Dict::State state) = 0; virtual void setCountOcc(bool countOcc) = 0; virtual void removeRareDictElements(float rarityThreshold) = 0; }; TORCH_MODULE(NeuralNetwork); #endif