Skip to content
Snippets Groups Projects
LSTMNetwork.hpp 1.32 KiB
Newer Older
Franck Dary's avatar
Franck Dary committed
#ifndef LSTMNETWORK__H
#define LSTMNETWORK__H

#include "NeuralNetwork.hpp"

/// Network module derived from NeuralNetworkImpl, composed of embedding,
/// dropout, linear and LSTM sub-modules.
///
/// Only declarations are visible in this header; the constructor, forward()
/// and extractContext() are defined elsewhere (presumably LSTMNetwork.cpp).
class LSTMNetworkImpl : public NeuralNetworkImpl
{
  private :

  /// Compile-time constant named "max number of embeddings".
  /// NOTE(review): presumably caps the size of the wordEmbeddings table —
  /// confirm against the constructor definition.
  static constexpr int maxNbEmbeddings = 50000;

  // Scalar settings are value-initialized so they never hold an indeterminate
  // value, even if the constructor does not assign one of them. Values set by
  // the constructor override these defaults.
  int unknownValueThreshold{0};
  std::vector<int> focusedBufferIndexes;
  std::vector<int> focusedStackIndexes;
  std::vector<std::string> focusedColumns;
  std::vector<int> maxNbElements;
  int leftWindowRawInput{0};
  int rightWindowRawInput{0};
  int rawInputSize{0};

  // Sub-modules start as empty holders ({nullptr}) and are expected to be
  // created and registered by the constructor.
  torch::nn::Embedding wordEmbeddings{nullptr};
  torch::nn::Dropout embeddingsDropout{nullptr};
  torch::nn::Dropout lstmDropout{nullptr};
  torch::nn::Dropout hiddenDropout{nullptr};
  torch::nn::Linear linear1{nullptr};
  torch::nn::Linear linear2{nullptr};
  torch::nn::LSTM contextLSTM{nullptr};
  torch::nn::LSTM rawInputLSTM{nullptr};
  /// NOTE(review): presumably one LSTM per focused column / element group —
  /// confirm against the constructor definition.
  std::vector<torch::nn::LSTM> lstms;

  public :

  /// Builds the network. Parameter meanings are inferred from names only.
  /// @param nbOutputs            number of outputs (presumably the output size of the final layer — confirm)
  /// @param unknownValueThreshold threshold stored in the member of the same name
  /// @param leftBorder           left context size (semantics defined in the .cpp)
  /// @param rightBorder          right context size (semantics defined in the .cpp)
  /// @param nbStackElements      number of stack elements considered (semantics defined in the .cpp)
  /// @param columns              column names used for the main context
  /// @param focusedBufferIndexes buffer indexes given focused treatment
  /// @param focusedStackIndexes  stack indexes given focused treatment
  /// @param focusedColumns       column names given focused treatment
  /// @param maxNbElements        per-focused-column element limits (presumably parallel to focusedColumns — confirm)
  /// @param leftWindowRawInput   left window size for the raw-input part
  /// @param rightWindowRawInput  right window size for the raw-input part
  LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);

  /// Runs the network on an input tensor (override of NeuralNetworkImpl).
  /// NOTE(review): input layout is defined by extractContext(); confirm shape in the .cpp.
  torch::Tensor forward(torch::Tensor input) override;

  /// Builds the numeric input context for a configuration, using dict to
  /// map values to indices (override of NeuralNetworkImpl).
  /// @return one or more index vectors (exact layout defined in the .cpp)
  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};

#endif