#ifndef LSTMNETWORK__H #define LSTMNETWORK__H #include "NeuralNetwork.hpp" class LSTMNetworkImpl : public NeuralNetworkImpl { private : static constexpr int maxNbEmbeddings = 50000; int unknownValueThreshold; std::vector<int> focusedBufferIndexes; std::vector<int> focusedStackIndexes; std::vector<std::string> focusedColumns; std::vector<int> maxNbElements; int leftWindowRawInput; int rightWindowRawInput; int rawInputSize; torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Dropout embeddingsDropout{nullptr}; torch::nn::Dropout lstmDropout{nullptr}; torch::nn::Dropout hiddenDropout{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; torch::nn::LSTM contextLSTM{nullptr}; torch::nn::LSTM rawInputLSTM{nullptr}; std::vector<torch::nn::LSTM> lstms; public : LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; }; #endif