#ifndef RLTNETWORK__H #define RLTNETWORK__H #include "NeuralNetwork.hpp" class RLTNetworkImpl : public NeuralNetworkImpl { private : static constexpr long maxNbChilds{8}; static inline std::vector<long> focusedBufferIndexes{0,1,2}; static inline std::vector<long> focusedStackIndexes{0,1}; int leftBorder, rightBorder; torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; torch::nn::LSTM vectorBiLSTM{nullptr}; torch::nn::LSTM treeLSTM{nullptr}; torch::Tensor S; torch::Tensor nullTree; public : RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; }; #endif