Newer
Older
#ifndef NEURALNETWORK__H
#define NEURALNETWORK__H
#include <torch/torch.h>
#include "Config.hpp"
#include "Dict.hpp"
class NeuralNetworkImpl : public torch::nn::Module
{
int leftBorder{5};
int rightBorder{5};
std::vector<std::string> columns{"FORM"};
protected :
void setRightBorder(int rightBorder);
void setLeftBorder(int leftBorder);
void setNbStackElements(int nbStackElements);
public :
virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual std::vector<long> extractContext(Config & config, Dict & dict) const;
void setColumns(const std::vector<std::string> & columns);
};
TORCH_MODULE(NeuralNetwork);
#endif