Skip to content
Snippets Groups Projects
NeuralNetwork.hpp 869 B
Newer Older
  • Learn to ignore specific revisions
  • #ifndef NEURALNETWORK__H
    #define NEURALNETWORK__H
    
    #include <torch/torch.h>
    #include "Config.hpp"
    #include "Dict.hpp"
    
    class NeuralNetworkImpl : public torch::nn::Module
    {
    
      public :
    
      static torch::Device device;
    
    
    Franck Dary's avatar
    Franck Dary committed
      protected : 
    
      unsigned leftBorder{5};
      unsigned rightBorder{5};
      unsigned nbStackElements{2};
    
      std::vector<std::string> columns{"FORM"};
    
    Franck Dary's avatar
    Franck Dary committed
    
      protected :
    
      void setRightBorder(int rightBorder);
      void setLeftBorder(int leftBorder);
      void setNbStackElements(int nbStackElements);
    
    
      public :
    
      virtual torch::Tensor forward(torch::Tensor input) = 0;
    
      virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const;
    
      std::vector<long> extractContextIndexes(const Config & config) const;
    
      int getContextSize() const;
    
      void setColumns(const std::vector<std::string> & columns);
    
    };
    TORCH_MODULE(NeuralNetwork);
    
    #endif