Newer
Older
#ifndef NEURALNETWORK__H
#define NEURALNETWORK__H
#include <torch/torch.h>
#include "Config.hpp"
#include "Dict.hpp"
class NeuralNetworkImpl : public torch::nn::Module
{
public :
static torch::Device device;
std::vector<std::string> columns{"FORM"};
std::vector<int> bufferContext{-3,-2,-1,0,1};
std::vector<int> stackContext{};
std::vector<int> bufferFocused{};
std::vector<int> stackFocused{};
void setBufferContext(const std::vector<int> & bufferContext);
void setStackContext(const std::vector<int> & stackContext);
void setBufferFocused(const std::vector<int> & bufferFocused);
void setStackFocused(const std::vector<int> & stackFocused);
public :
virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const;
std::vector<long> extractContextIndexes(const Config & config) const;
std::vector<long> extractFocusedIndexes(const Config & config) const;
void setColumns(const std::vector<std::string> & columns);
};
// libtorch's TORCH_MODULE macro generates the `NeuralNetwork` ModuleHolder
// wrapper (shared-ownership handle) around NeuralNetworkImpl. Because
// NeuralNetworkImpl has a pure virtual forward(), the holder cannot
// default-construct it; presumably callers build it from a std::shared_ptr
// to a concrete subclass — TODO confirm at the construction sites.
TORCH_MODULE(NeuralNetwork);
#endif