// NeuralNetwork.hpp
// Author: Franck Dary
#ifndef NEURALNETWORK__H
#define NEURALNETWORK__H
#include <torch/torch.h>
#include <filesystem>
#include "Config.hpp"
#include "NameHolder.hpp"
#include "StateHolder.hpp"
/// @brief Abstract base class for the project's libtorch networks.
///
/// Derives from torch::nn::Module and the NameHolder / StateHolder mix-ins,
/// and declares the pure-virtual interface that every concrete network must
/// implement. NOTE(review): the dictionary-related hooks below appear to manage
/// the lookup tables behind the embeddings (inferred from their names) — confirm
/// against the implementing classes.
class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public StateHolder
{
  public :

  /// @brief Runs the network on @p input and returns the resulting tensor.
  /// NOTE(review): input is presumably built from extractContext() output — confirm.
  virtual torch::Tensor forward(torch::Tensor input) = 0;

  /// @brief Builds the index context(s) for the given configuration.
  /// @return One vector of indices per context; the exact meaning is defined
  ///         by the concrete network.
  virtual std::vector<std::vector<long>> extractContext(Config & config) = 0;

  /// @brief Hook for a concrete network to register its embedding submodules.
  virtual void registerEmbeddings() = 0;

  /// @brief Persists the network's dictionaries under @p path.
  virtual void saveDicts(std::filesystem::path path) = 0;

  /// @brief Restores the network's dictionaries from @p path.
  virtual void loadDicts(std::filesystem::path path) = 0;

  /// @brief Propagates @p state to every dictionary owned by the network.
  virtual void setDictsState(Dict::State state) = 0;

  /// @brief Toggles occurrence counting in the dictionaries.
  virtual void setCountOcc(bool countOcc) = 0;

  /// @brief Drops dictionary entries considered rare w.r.t. @p rarityThreshold.
  /// NOTE(review): whether the threshold is a count or a frequency is decided
  /// by the implementations — confirm.
  virtual void removeRareDictElements(float rarityThreshold) = 0;

  /// Single device shared by all instances (static member, not per-object).
  /// Being a non-inline static data member, it needs exactly one out-of-line
  /// definition in some translation unit.
  static torch::Device device;

  private :

  /// NOTE(review): not referenced anywhere in this header, and StateHolder may
  /// already carry a state — confirm this member is still needed.
  std::string state;
};

// Declares the `NeuralNetwork` holder type that wraps a shared
// NeuralNetworkImpl instance (libtorch module-holder idiom).
TORCH_MODULE(NeuralNetwork);
#endif