Newer
Older
#ifndef CNNNETWORK__H
#define CNNNETWORK__H
#include "NeuralNetwork.hpp"
class CNNNetworkImpl : public NeuralNetworkImpl
{
private :
Franck Dary
committed
static constexpr int maxNbEmbeddings = 50000;
int unknownValueThreshold;
std::vector<int> maxNbElements;
int leftWindowRawInput;
int rightWindowRawInput;
int rawInputSize;
torch::nn::Dropout cnnDropout{nullptr};
torch::nn::Dropout hiddenDropout{nullptr};
torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr};
CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;