#ifndef NEURALNETWORK__H
#define NEURALNETWORK__H

#include <torch/torch.h>
#include "Config.hpp"
#include "Dict.hpp"

class NeuralNetworkImpl : public torch::nn::Module
{
  public :

  static torch::Device device;

  private :

  bool splitUnknown{false};
  std::string state;

  protected : 

  static constexpr int maxNbEmbeddings = 150000;

  public :

  virtual torch::Tensor forward(torch::Tensor input) = 0;
  virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0;
  bool mustSplitUnknown() const;
  void setSplitUnknown(bool splitUnknown);
  void setState(const std::string & state);
  const std::string & getState() const;
};
// Generate the `NeuralNetwork` module holder for NeuralNetworkImpl using
// libtorch's TORCH_MODULE convention (`NeuralNetwork` wraps a shared pointer to
// the Impl). NOTE(review): NeuralNetworkImpl is abstract, so holders are
// presumably built from concrete subclass instances — confirm with callers.
TORCH_MODULE(NeuralNetwork);

#endif