// RLTNetwork.hpp
    #ifndef RLTNETWORK__H
    #define RLTNETWORK__H
    
    // Author: Franck Dary
    
    #include "NeuralNetwork.hpp"
    
    
    class RLTNetworkImpl : public NeuralNetworkImpl
    
    Franck Dary's avatar
    Franck Dary committed
    {
      private :
    
    
    Franck Dary's avatar
    Franck Dary committed
      static constexpr long maxNbChilds{8};
      static inline std::vector<long> focusedBufferIndexes{0,1,2};
      static inline std::vector<long> focusedStackIndexes{0,1};
    
    
    Franck Dary's avatar
    Franck Dary committed
      torch::nn::Embedding wordEmbeddings{nullptr};
      torch::nn::Linear linear1{nullptr};
      torch::nn::Linear linear2{nullptr};
    
    Franck Dary's avatar
    Franck Dary committed
      torch::nn::LSTM vectorBiLSTM{nullptr};
      torch::nn::LSTM treeLSTM{nullptr};
      torch::Tensor S;
      torch::Tensor nullTree;
    
    Franck Dary's avatar
    Franck Dary committed
    
      public :
    
    
      RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
    
    Franck Dary's avatar
    Franck Dary committed
      torch::Tensor forward(torch::Tensor input) override;
    
    Franck Dary's avatar
    Franck Dary committed
      std::vector<long> extractContext(Config & config, Dict & dict) const override;
    
    Franck Dary's avatar
    Franck Dary committed
    };
    
    #endif