Newer
Older
#ifndef NEURALNETWORK__H
#define NEURALNETWORK__H
#include <torch/torch.h>
#include "Config.hpp"
#include "Dict.hpp"
class NeuralNetworkImpl : public torch::nn::Module
{
private :
int leftBorder{5};
int rightBorder{5};
int nbStackElements{2};
protected :
void setRightBorder(int rightBorder);
void setLeftBorder(int leftBorder);
void setNbStackElements(int nbStackElements);
public :
virtual std::vector<torch::Tensor> & denseParameters() = 0;
virtual std::vector<torch::Tensor> & sparseParameters() = 0;
virtual torch::Tensor forward(torch::Tensor input) = 0;
std::vector<long> extractContext(Config & config, Dict & dict) const;
int getContextSize() const;
};
// Declares `NeuralNetwork`, the libtorch ModuleHolder wrapper around
// NeuralNetworkImpl (see TORCH_MODULE in torch/nn/pimpl.h). Callers use
// the holder type, which shares ownership of the underlying Impl object.
TORCH_MODULE(NeuralNetwork);
#endif