"reading_machine/git@gitlab.lis-lab.fr:franck.dary/macaon.git" did not exist on "209dd02c5d62330c06a71d9a631369c4b8849fc3"
Newer
Older
#ifndef RLTNETWORK__H
#define RLTNETWORK__H
class RLTNetworkImpl : public NeuralNetworkImpl
static constexpr long maxNbChilds{8};
static inline std::vector<long> focusedBufferIndexes{0,1,2};
static inline std::vector<long> focusedStackIndexes{0,1};
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr};
torch::nn::LSTM vectorBiLSTM{nullptr};
torch::nn::LSTM treeLSTM{nullptr};
torch::Tensor S;
torch::Tensor nullTree;
RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
torch::Tensor forward(torch::Tensor input) override;
std::vector<long> extractContext(Config & config, Dict & dict) const override;