// RLTNetwork.hpp
#ifndef RLTNETWORK__H
#define RLTNETWORK__H
// Declares RLTNetworkImpl, a NeuralNetworkImpl subclass.

#include "NeuralNetwork.hpp"

class RLTNetworkImpl : public NeuralNetworkImpl
Franck Dary's avatar
Franck Dary committed
{
  private :

Franck Dary's avatar
Franck Dary committed
  static constexpr long maxNbChilds{8};
  static inline std::vector<long> focusedBufferIndexes{0,1,2};
  static inline std::vector<long> focusedStackIndexes{0,1};

Franck Dary's avatar
Franck Dary committed
  torch::nn::Embedding wordEmbeddings{nullptr};
  torch::nn::Linear linear1{nullptr};
  torch::nn::Linear linear2{nullptr};
Franck Dary's avatar
Franck Dary committed
  torch::nn::LSTM vectorBiLSTM{nullptr};
  torch::nn::LSTM treeLSTM{nullptr};
  torch::Tensor S;
  torch::Tensor nullTree;
Franck Dary's avatar
Franck Dary committed

  public :

  RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
Franck Dary's avatar
Franck Dary committed
  torch::Tensor forward(torch::Tensor input) override;
  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
Franck Dary's avatar
Franck Dary committed
};

#endif