Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#ifndef LSTMNETWORK__H
#define LSTMNETWORK__H
#include "NeuralNetwork.hpp"
// LSTM-flavoured network implementation built on top of NeuralNetworkImpl.
// Only the declaration lives here; the constructor, forward() and
// extractContext() are defined in another translation unit.
// NOTE(review): the per-member remarks below are inferred from names and
// types only -- confirm them against the implementation file.
class LSTMNetworkImpl : public NeuralNetworkImpl
{
  private:

  // Compile-time constant; presumably the capacity of the embedding table.
  static constexpr int maxNbEmbeddings{50000};

  // Plain configuration values. Most share their name with a constructor
  // parameter, so they are presumably copied from it -- TODO confirm.
  int unknownValueThreshold;
  std::vector<int> focusedBufferIndexes;
  std::vector<int> focusedStackIndexes;
  std::vector<std::string> focusedColumns;
  std::vector<int> maxNbElements;
  int leftWindowRawInput;
  int rightWindowRawInput;
  int rawInputSize; // no matching constructor parameter; presumably derived

  // Torch module holders. Each one is brace-initialised with nullptr here,
  // so the actual modules have to be created elsewhere (presumably in the
  // constructor body).
  torch::nn::Embedding wordEmbeddings{nullptr};
  torch::nn::Dropout embeddingsDropout{nullptr};
  torch::nn::Dropout lstmDropout{nullptr};
  torch::nn::Dropout hiddenDropout{nullptr};
  torch::nn::Linear linear1{nullptr};
  torch::nn::Linear linear2{nullptr};
  torch::nn::LSTM contextLSTM{nullptr};
  torch::nn::LSTM rawInputLSTM{nullptr};

  // Additional LSTM modules, default-constructed as an empty vector.
  std::vector<torch::nn::LSTM> lstms;

  public:

  // Builds the network. leftBorder, rightBorder, nbStackElements and columns
  // have no same-named member in this class; presumably they are consumed by
  // the base class or by the constructor body -- verify in the definition.
  LSTMNetworkImpl(
    int nbOutputs,
    int unknownValueThreshold,
    int leftBorder,
    int rightBorder,
    int nbStackElements,
    std::vector<std::string> columns,
    std::vector<int> focusedBufferIndexes,
    std::vector<int> focusedStackIndexes,
    std::vector<std::string> focusedColumns,
    std::vector<int> maxNbElements,
    int leftWindowRawInput,
    int rightWindowRawInput);

  // Overrides the base-class forward pass: takes a tensor, returns a tensor.
  torch::Tensor forward(torch::Tensor input) override;

  // Overrides the base-class context extraction. Const member function that
  // receives a Config and a Dict by non-const reference and returns a vector
  // of vectors of long.
  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};
#endif