Newer
Older
#ifndef CNNNETWORK__H
#define CNNNETWORK__H
#include "NeuralNetwork.hpp"
/// Convolutional network module derived from NeuralNetworkImpl.
///
/// This header only declares the layers. The class owns a word embedding,
/// two linear layers and two collections of 2-D convolutions. The
/// constructor, forward pass and context extraction are defined out of line.
class CNNNetworkImpl : public NeuralNetworkImpl
{
  public :

  /// Builds the network.
  /// @param nbOutputs       NOTE(review): presumably the width of the final
  ///                        output layer — confirm in the .cpp.
  /// @param leftBorder      NOTE(review): context-window bound; exact meaning
  ///                        is defined by the implementation.
  /// @param rightBorder     NOTE(review): context-window bound; exact meaning
  ///                        is defined by the implementation.
  /// @param nbStackElements NOTE(review): presumably how many stack elements
  ///                        are considered — confirm in the .cpp.
  CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);

  /// Runs the network on `input` and returns the resulting tensor.
  /// Overrides the NeuralNetworkImpl virtual.
  torch::Tensor forward(torch::Tensor input) override;

  /// Extracts a vector of `long` values from `config`, using `dict`.
  /// The result is presumably the index input later fed to forward();
  /// TODO confirm. This does not modify the network (const member).
  std::vector<long> extractContext(Config & config, Dict & dict) const override;

  private :

  // Class-wide (static) settings shared by every instance of this network.
  inline static std::vector<long> focusedBufferIndexes{0,1};
  inline static std::vector<long> focusedStackIndexes{0,1};
  // NOTE(review): presumably one convolution per window size — confirm in .cpp.
  inline static std::vector<long> windowSizes{2,3,4};
  static constexpr unsigned int maxNbLetters = 10;

  // Layers start empty (nullptr) and are expected to be created by the
  // constructor. Keep this declaration order: members are initialized in
  // declaration order, regardless of any mem-initializer list.
  torch::nn::Embedding wordEmbeddings{nullptr};
  torch::nn::Linear linear1{nullptr};
  torch::nn::Linear linear2{nullptr};
  std::vector<torch::nn::Conv2d> CNNs;
  std::vector<torch::nn::Conv2d> lettersCNNs;
};
#endif