// NeuralNetwork.hpp
#ifndef NEURALNETWORK__H
#define NEURALNETWORK__H

#include <torch/torch.h>
#include <filesystem>
#include "Config.hpp"
#include "NameHolder.hpp"
#include "StateHolder.hpp"

/// Abstract base class for the project's neural networks.
///
/// Combines a libtorch module (torch::nn::Module) with the NameHolder and
/// StateHolder mix-ins, and declares the pure-virtual interface every concrete
/// network must implement. Callers normally hold networks through the
/// `NeuralNetwork` handle declared right after this class via TORCH_MODULE.
class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public StateHolder
{
  public :

  /// Device shared by every network (static: one value for the whole program).
  /// This is only a declaration; the definition must appear in exactly one
  /// translation unit (presumably NeuralNetwork.cpp, not visible here).
  static torch::Device device;

  /// Runs the network on `input` and returns the resulting tensor.
  /// NOTE(review): input/output shapes are defined by each implementation;
  /// nothing in this interface fixes them.
  virtual torch::Tensor forward(torch::Tensor input) = 0;

  /// Builds the integer context for `config`.
  /// `config` is taken by non-const reference, so implementations may modify it.
  /// @return nested vectors of `long` values. NOTE(review): presumably
  ///         dictionary indices later fed to forward(); confirm against the
  ///         implementations.
  virtual std::vector<std::vector<long>> extractContext(Config & config) = 0;

  /// Hook for registering the network's embedding submodules
  /// (presumably; the exact behaviour is implementation-defined).
  virtual void registerEmbeddings() = 0;

  /// Persistence hooks for the network's dictionaries, addressed by `path`.
  /// NOTE(review): whether `path` names a file or a directory is up to the
  /// implementation; confirm before relying on it.
  virtual void saveDicts(std::filesystem::path path) = 0;
  virtual void loadDicts(std::filesystem::path path) = 0;

  /// Applies a Dict::State to the network's dictionaries (presumably all of them).
  virtual void setDictsState(Dict::State state) = 0;

  /// Enables or disables occurrence counting in the network's dictionaries
  /// (presumably; implementation-defined).
  virtual void setCountOcc(bool countOcc) = 0;

  /// Prunes rare dictionary entries according to `rarityThreshold`.
  /// NOTE(review): the exact rarity criterion is implementation-defined.
  virtual void removeRareDictElements(float rarityThreshold) = 0;
};

/// Declares `NeuralNetwork`, the torch::nn::ModuleHolder handle that holds a
/// NeuralNetworkImpl through a shared pointer (standard libtorch module pattern).
TORCH_MODULE(NeuralNetwork);

#endif