// NeuralNetwork.hpp
#ifndef NEURALNETWORK__H
#define NEURALNETWORK__H

#include <torch/torch.h>
#include "Config.hpp"
#include "Dict.hpp"

class NeuralNetworkImpl : public torch::nn::Module
{
  public :

  static torch::Device device;

Franck Dary's avatar
Franck Dary committed
  protected : 
  std::vector<std::string> columns{"FORM"};
  std::vector<int> bufferContext{-3,-2,-1,0,1};
  std::vector<int> stackContext{};
  std::vector<int> bufferFocused{};
  std::vector<int> stackFocused{};
Franck Dary's avatar
Franck Dary committed

  protected :

  void setBufferContext(const std::vector<int> & bufferContext);
  void setStackContext(const std::vector<int> & stackContext);
  void setBufferFocused(const std::vector<int> & bufferFocused);
  void setStackFocused(const std::vector<int> & stackFocused);

  public :

  virtual torch::Tensor forward(torch::Tensor input) = 0;
  virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const;
  std::vector<long> extractContextIndexes(const Config & config) const;
  std::vector<long> extractFocusedIndexes(const Config & config) const;
  int getContextSize() const;
  void setColumns(const std::vector<std::string> & columns);
};
// libtorch's TORCH_MODULE macro declares `NeuralNetwork` as a
// torch::nn::ModuleHolder<NeuralNetworkImpl>: a shared-pointer-style handle to
// a NeuralNetworkImpl. NeuralNetworkImpl is abstract, so a holder is
// presumably built from a concrete subclass instance — TODO confirm at call sites.
TORCH_MODULE(NeuralNetwork);

#endif