Skip to content
Snippets Groups Projects
CNNNetwork.hpp 761 B
Newer Older
  • Learn to ignore specific revisions
  • Franck Dary's avatar
    Franck Dary committed
    #ifndef CNNNETWORK__H
    #define CNNNETWORK__H
    
    #include "NeuralNetwork.hpp"
    
    /// Convolutional network model: a concrete NeuralNetworkImpl built from an
    /// embedding layer, two linear layers and two groups of Conv2d modules.
    /// NOTE(review): the actual wiring of these layers is in the matching .cpp
    /// and is not visible from this header.
    class CNNNetworkImpl : public NeuralNetworkImpl
    {
      public:

      /// Constructs the network.
      /// NOTE(review): parameter meanings (number of outputs, left/right context
      /// borders, number of stack elements) are inferred from their names only —
      /// confirm against the constructor definition.
      CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);

      /// Runs the model on `input`; overrides NeuralNetworkImpl::forward.
      torch::Tensor forward(torch::Tensor input) override;

      /// Returns a vector of `long` values computed from `config` and `dict`
      /// without modifying this object (const); overrides
      /// NeuralNetworkImpl::extractContext.
      /// NOTE(review): presumably these are the model's input indices — confirm.
      std::vector<long> extractContext(Config & config, Dict & dict) const override;

      private:

      // Class-wide settings shared by every instance (C++17 inline static members).
      // NOTE(review): presumably the buffer positions whose letters are fed to
      // lettersCNNs — confirm in the .cpp.
      inline static std::vector<long> focusedBufferIndexes = {0, 1};
      // NOTE(review): presumably the convolution window sizes used for CNNs — confirm.
      inline static std::vector<long> windowSizes = {2, 3, 4};
      // NOTE(review): presumably the maximum number of letters taken per word — confirm.
      static constexpr unsigned int maxNbLetters{10};

      // Module holders start empty ({nullptr}); presumably they are created and
      // registered in the constructor.
      torch::nn::Embedding wordEmbeddings{nullptr};
      torch::nn::Linear linear1{nullptr};
      torch::nn::Linear linear2{nullptr};
      std::vector<torch::nn::Conv2d> CNNs;
      std::vector<torch::nn::Conv2d> lettersCNNs;
    };
    
    #endif