Skip to content
Snippets Groups Projects
OneWordNetwork.hpp 604 B
Newer Older
  • Learn to ignore specific revisions
  • #ifndef ONEWORDNETWORK__H
    #define ONEWORDNETWORK__H
    
    #include "NeuralNetwork.hpp"
    
    /**
     * @brief Network module declared on top of NeuralNetworkImpl.
     *
     * The class owns an embedding layer, a linear layer, an integer index and
     * two tensor lists exposed through denseParameters() / sparseParameters().
     *
     * NOTE(review): only declarations are visible in this header; the bodies of
     * the constructor, forward() and the accessors live elsewhere, so the
     * descriptions below are limited to what the signatures show.
     */
    class OneWordNetworkImpl : public NeuralNetworkImpl
    {
    public:
      /**
       * @param nbOutputs    Integer output count (presumably the output size of
       *                     the linear layer -- confirm in the implementation).
       * @param focusedIndex Integer index (presumably stored in the member of
       *                     the same name -- confirm in the implementation).
       */
      OneWordNetworkImpl(int nbOutputs, int focusedIndex);

      /// Overrides the base-class forward pass: takes one tensor, returns one tensor.
      torch::Tensor forward(torch::Tensor input) override;

      /// Overrides the base-class accessor; returns a mutable reference to a
      /// tensor vector (presumably _denseParameters -- confirm).
      std::vector<torch::Tensor> & denseParameters() override;

      /// Overrides the base-class accessor; returns a mutable reference to a
      /// tensor vector (presumably _sparseParameters -- confirm).
      std::vector<torch::Tensor> & sparseParameters() override;

    private:
      // Both layers are initialised from nullptr in-class; they are expected to
      // be given real modules by the constructor (not visible here).
      torch::nn::Embedding wordEmbeddings{nullptr};
      torch::nn::Linear linear{nullptr};

      // No in-class initialiser: the value must come from the constructor.
      int focusedIndex;

      // Tensor lists backing the two parameter accessors (assumed; verify).
      std::vector<torch::Tensor> _denseParameters;
      std::vector<torch::Tensor> _sparseParameters;
    };
    
    #endif