#ifndef NEURALNETWORK__H
#define NEURALNETWORK__H

#include <filesystem>
#include <map>
#include <string>

#include <torch/torch.h>

#include "Config.hpp"
#include "NameHolder.hpp"

/// Abstract base class for the project's networks.
///
/// It is both a libtorch module and a NameHolder. Concrete networks implement
/// the pure virtual interface below. The device is a single static value
/// shared by every instance of the class.
class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
{
  public :

  /// Returns the tensor registered for `stateName`.
  /// NOTE(review): presumably backed by `lossParameters` (lookup and/or lazy
  /// creation) — confirm against the .cpp definition.
  torch::Tensor getLossParameter(std::string stateName);

  // ---- Interface every concrete network must implement ----
  // Keep the relative order of these virtuals stable: on common ABIs it
  // determines the vtable layout.

  /// Runs the network on `input`. `state` is a caller-supplied state name
  /// (presumably selecting per-state behaviour — TODO confirm).
  virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0;

  /// Builds an input tensor from `config`.
  virtual torch::Tensor extractContext(Config & config) = 0;

  /// Registers the network's embeddings. `usePretrained` is forwarded
  /// unchanged (presumably requests pretrained weights — TODO confirm).
  virtual void registerEmbeddings(bool usePretrained) = 0;

  /// Persists the network's dictionaries to `target`.
  virtual void saveDicts(std::filesystem::path target) = 0;

  /// Restores the network's dictionaries from `source`.
  virtual void loadDicts(std::filesystem::path source) = 0;

  /// Applies `dictState` to the network's dictionaries.
  virtual void setDictsState(Dict::State dictState) = 0;

  /// Enables or disables occurrence counting in the dictionaries.
  virtual void setCountOcc(bool enabled) = 0;

  /// Drops dictionary entries considered rare according to `threshold`.
  virtual void removeRareDictElements(float threshold) = 0;

  // ---- Device handling (class-wide, backed by the static `device`) ----

  static torch::Device getPreferredDevice();
  static torch::Device getDevice();
  static void setDevice(torch::Device newDevice);

  /// Returns the entropy of `probabilities` as a float.
  /// NOTE(review): assumes the tensor holds a probability distribution — confirm.
  static float entropy(torch::Tensor probabilities);

  private :

  /// Device shared by all NeuralNetworkImpl instances (one static object).
  static torch::Device device;

  /// Tensors keyed by name; presumably the storage behind getLossParameter().
  std::map<std::string, torch::Tensor> lossParameters;
};
// libtorch's TORCH_MODULE macro declares `NeuralNetwork`, a ModuleHolder
// (shared-pointer-like handle) wrapping NeuralNetworkImpl. The macro finds the
// implementation class by appending "Impl" to the given name.
// NeuralNetworkImpl declares pure virtual members, so the holder cannot
// construct one itself. It is presumably built from an instance of a concrete
// subclass — confirm with callers.
TORCH_MODULE(NeuralNetwork);

#endif