// CNNNetwork.hpp
// Last committed by: Franck Dary
#ifndef CNNNETWORK__H
#define CNNNETWORK__H

#include "NeuralNetwork.hpp"

/// Convolutional network implementation derived from NeuralNetworkImpl.
///
/// Only the declarations are visible in this header. The descriptions
/// below are inferred from names and signatures and should be confirmed
/// against the out-of-line definitions (CNNNetwork.cpp).
class CNNNetworkImpl : public NeuralNetworkImpl
{
  private :

  /// Shared by every instance (static) and mutable (non-const).
  /// Presumably the window/kernel sizes of the convolutions in CNNs.
  /// TODO confirm.
  static inline std::vector<long> windowSizes{2,3,4};
  /// Compile-time constant; presumably the maximum number of letters per
  /// word considered by lettersCNNs. TODO confirm.
  static constexpr unsigned int maxNbLetters = 10;

  /// Same names as the corresponding constructor parameters. Presumably
  /// they store the buffer indexes, stack indexes and column names the
  /// network focuses on. Verify in the constructor definition.
  std::vector<long> focusedBufferIndexes;
  std::vector<long> focusedStackIndexes;
  std::vector<std::string> focusedColumns;

  // Sub-modules. The single modules start as nullptr holders and the
  // vectors start empty; presumably they are built in the constructor.
  torch::nn::Embedding wordEmbeddings{nullptr};
  torch::nn::Linear linear1{nullptr};
  torch::nn::Linear linear2{nullptr};
  std::vector<torch::nn::Conv2d> CNNs;
  std::vector<torch::nn::Conv2d> lettersCNNs;

  public :

  /// Builds the network.
  /// NOTE(review): parameter meanings are not visible here; presumably:
  /// @param nbOutputs            size of the output layer
  /// @param leftBorder           context extent to the left in the buffer
  /// @param rightBorder          context extent to the right in the buffer
  /// @param nbStackElements      number of stack elements in the context
  /// @param columns              columns used as features
  /// @param focusedBufferIndexes buffer positions given extra (focused) treatment
  /// @param focusedStackIndexes  stack positions given extra (focused) treatment
  /// @param focusedColumns       columns used for the focused elements
  CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<long> focusedBufferIndexes, std::vector<long> focusedStackIndexes, std::vector<std::string> focusedColumns);

  /// Forward pass; overrides NeuralNetworkImpl::forward.
  torch::Tensor forward(torch::Tensor input) override;
  /// Builds the index vector fed to forward() from the current
  /// configuration (assumed; confirm). Overrides the base-class method.
  /// It is const, but `dict` is taken by non-const reference, so the
  /// dictionary may be modified (e.g. new entries inserted). Confirm.
  std::vector<long> extractContext(Config & config, Dict & dict) const override;
};

#endif