Skip to content
Snippets Groups Projects
NeuralNetwork.hpp 725 B
Newer Older
// Include guard. Named without a double underscore: identifiers containing
// "__" are reserved for the implementation in C++ ([lex.name]).
#ifndef NEURALNETWORK_HPP
#define NEURALNETWORK_HPP

#include <torch/torch.h>
#include "Config.hpp"
#include "Dict.hpp"

class NeuralNetworkImpl : public torch::nn::Module
{
Franck Dary's avatar
Franck Dary committed
  protected : 

  int leftBorder{5};
  int rightBorder{5};
Franck Dary's avatar
Franck Dary committed
  int nbStackElements{2};
  std::vector<std::string> columns{"FORM"};
Franck Dary's avatar
Franck Dary committed

  protected :

  void setRightBorder(int rightBorder);
  void setLeftBorder(int leftBorder);
  void setNbStackElements(int nbStackElements);

  public :

  virtual torch::Tensor forward(torch::Tensor input) = 0;
Franck Dary's avatar
Franck Dary committed
  virtual std::vector<long> extractContext(Config & config, Dict & dict) const;
  int getContextSize() const;
  void setColumns(const std::vector<std::string> & columns);
};
// Declares `NeuralNetwork`, libtorch's module-holder type for NeuralNetworkImpl
// (a ModuleHolder wrapping a shared pointer to the Impl class).
TORCH_MODULE(NeuralNetwork);

#endif