Skip to content
Snippets Groups Projects
NeuralNetwork.hpp 662 B
Newer Older
  • Learn to ignore specific revisions
  • #ifndef NEURALNETWORK__H
    #define NEURALNETWORK__H
    
    #include <torch/torch.h>
    #include "Config.hpp"
    #include "Dict.hpp"
    
    class NeuralNetworkImpl : public torch::nn::Module
    {
      private : 
    
      int leftBorder{5};
      int rightBorder{5};
    
    Franck Dary's avatar
    Franck Dary committed
      int nbStackElements{2};
    
      std::vector<std::string> columns{"FORM", "UPOS"};
    
    Franck Dary's avatar
    Franck Dary committed
    
      protected :
    
      void setRightBorder(int rightBorder);
      void setLeftBorder(int leftBorder);
      void setNbStackElements(int nbStackElements);
    
    
      public :
    
      virtual torch::Tensor forward(torch::Tensor input) = 0;
      std::vector<long> extractContext(Config & config, Dict & dict) const;
      int getContextSize() const;
    };
    // Per the libtorch module-ownership docs, TORCH_MODULE(NeuralNetwork)
    // defines `NeuralNetwork`, a torch::nn::ModuleHolder<NeuralNetworkImpl>
    // that wraps a std::shared_ptr<NeuralNetworkImpl>. NeuralNetworkImpl is
    // abstract (forward() is pure virtual), so a holder must be built from a
    // concrete subclass instance rather than default-constructed.
    TORCH_MODULE(NeuralNetwork);
    
    #endif