Newer
Older
#ifndef RTLSTMNETWORK__H
#define RTLSTMNETWORK__H
#include "NeuralNetwork.hpp"
/// Network module deriving from NeuralNetworkImpl that owns a word
/// embedding, two linear layers, a dropout layer and an LSTM.
///
/// How these layers are wired together is defined in the constructor and
/// forward() implementations, which are not part of this header.
/// Every sub-module is declared as an empty libtorch module holder (nullptr).
/// NOTE(review): presumably the constructor allocates/registers them —
/// confirm in the implementation file.
class RTLSTMNetworkImpl : public NeuralNetworkImpl
{
  public :

  /// Builds the network.
  /// @param nbOutputs       presumably the size of the final output layer — TODO confirm
  /// @param leftBorder      meaning not visible in this header — TODO confirm
  /// @param rightBorder     meaning not visible in this header — TODO confirm
  /// @param nbStackElements meaning not visible in this header — TODO confirm
  RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);

  /// Runs the network on `input` (overrides NeuralNetworkImpl::forward).
  /// NOTE(review): expected input shape/dtype is not visible here — confirm.
  torch::Tensor forward(torch::Tensor input) override;

  private :

  // Sub-modules, declared empty; declaration order kept as in the original.
  torch::nn::Embedding wordEmbeddings = nullptr;
  torch::nn::Linear linear1 = nullptr;
  torch::nn::Linear linear2 = nullptr;
  torch::nn::Dropout dropout = nullptr;
  torch::nn::LSTM lstm = nullptr;
};
#endif