#ifndef NEURALNETWORK__H #define NEURALNETWORK__H #include <torch/torch.h> #include <filesystem> #include "Config.hpp" #include "NameHolder.hpp" class NeuralNetworkImpl : public torch::nn::Module, public NameHolder { private : static torch::Device device; std::map<std::string, torch::Tensor> lossParameters; public : torch::Tensor getLossParameter(std::string state); virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0; virtual torch::Tensor extractContext(Config & config) = 0; virtual void registerEmbeddings(bool loadPretrained) = 0; virtual void saveDicts(std::filesystem::path path) = 0; virtual void loadDicts(std::filesystem::path path) = 0; virtual void setDictsState(Dict::State state) = 0; virtual void setCountOcc(bool countOcc) = 0; virtual void removeRareDictElements(float rarityThreshold) = 0; static torch::Device getPreferredDevice(); static torch::Device getDevice(); static void setDevice(torch::Device device); static float entropy(torch::Tensor probabilities); }; TORCH_MODULE(NeuralNetwork); #endif