// CNNNetwork.hpp — authored by Franck Dary
#ifndef CNNNETWORK__H
#define CNNNETWORK__H
#include "NeuralNetwork.hpp"
#include "CNN.hpp"
/// @brief CNN-based network implementation deriving from NeuralNetworkImpl.
///
/// Holds a word-embedding table, two linear layers, two named CNN
/// sub-modules plus a list of additional CNNs, and the feature-selection
/// settings handed to the constructor.
///
/// NOTE(review): the "Impl" suffix suggests the libtorch TORCH_MODULE
/// holder/impl convention; the matching holder declaration is not visible
/// here — confirm where (or whether) it is declared.
class CNNNetworkImpl : public NeuralNetworkImpl
{
  private:

  // Feature-selection settings. The constructor takes parameters with the
  // same names; presumably they are stored here — confirm in the .cpp.
  std::vector<int> focusedBufferIndexes;
  std::vector<int> focusedStackIndexes;
  std::vector<std::string> focusedColumns;
  std::vector<int> maxNbElements;

  // Raw-input window sizes. They default to 5, and the constructor
  // signature has no matching parameter.
  int leftWindowRawInput = 5;
  int rightWindowRawInput = 5;

  // Sub-modules. Each starts as an empty (nullptr) holder and is presumably
  // constructed and registered in the constructor — confirm in the .cpp.
  // Declaration order is kept deliberately: it fixes member initialization
  // order.
  torch::nn::Embedding wordEmbeddings{nullptr};
  torch::nn::Linear linear1{nullptr};
  torch::nn::Linear linear2{nullptr};
  CNN contextCNN{nullptr};
  CNN rawInputCNN{nullptr};
  std::vector<CNN> cnns;

  public:

  /// @brief Builds the network.
  /// @param nbOutputs            number of outputs of the network
  /// @param leftBorder           presumably the left context size — confirm
  /// @param rightBorder          presumably the right context size — confirm
  /// @param nbStackElements      presumably the number of stack elements used — confirm
  /// @param columns              column names used as input features
  /// @param focusedBufferIndexes buffer positions given focused treatment
  /// @param focusedStackIndexes  stack positions given focused treatment
  /// @param focusedColumns       columns given focused treatment
  /// @param maxNbElements        per-focused-column element limits (assumed
  ///                             parallel to focusedColumns — confirm)
  CNNNetworkImpl(int nbOutputs,
                 int leftBorder,
                 int rightBorder,
                 int nbStackElements,
                 std::vector<std::string> columns,
                 std::vector<int> focusedBufferIndexes,
                 std::vector<int> focusedStackIndexes,
                 std::vector<std::string> focusedColumns,
                 std::vector<int> maxNbElements);

  /// @brief Forward pass; overrides NeuralNetworkImpl::forward.
  /// NOTE(review): input is presumably the id tensor built from
  /// extractContext's output — confirm shape/dtype at the call site.
  torch::Tensor forward(torch::Tensor input) override;

  /// @brief Builds the numeric context (a vector of longs) for @p config.
  /// Takes @p dict by non-const reference, so it may modify it.
  std::vector<long> extractContext(Config & config, Dict & dict) const override;
};
#endif