-
Franck Dary authored
ContextLSTM.hpp 927 B
#ifndef CONTEXTLSTM__H
#define CONTEXTLSTM__H
#include <torch/torch.h>
#include "Submodule.hpp"
#include "LSTM.hpp"
// Torch module that also implements the project's Submodule interface and
// holds an LSTM together with the configuration describing which columns and
// which buffer/stack positions it works on.
// NOTE(review): the exact meaning of the context indices and of the threshold
// is defined by the implementation file, not by this header — confirm there.
class ContextLSTMImpl : public torch::nn::Module, public Submodule
{
  private:

  // Wrapped LSTM; constructed from nullptr here, presumably built for real in
  // the constructor — confirm in the .cpp.
  LSTM lstm{nullptr};
  // Column names handed to the constructor.
  std::vector<std::string> columns;
  // Positions taken from the buffer (presumably relative indexes — confirm).
  std::vector<int> bufferContext;
  // Positions taken from the stack (presumably relative indexes — confirm).
  std::vector<int> stackContext;
  // Threshold related to unknown values; no in-class default, so it is
  // expected to be set by the constructor.
  int unknownValueThreshold;
  // Columns associated with unknown-value handling; defaults to FORM and LEMMA.
  std::vector<std::string> unknownValueColumns{"FORM", "LEMMA"};

  public:

  // Builds the submodule from its column list, the embedding sizes, the
  // buffer/stack context descriptions, the LSTM options and the unknown-value
  // threshold.
  ContextLSTMImpl(std::vector<std::string> columns,
                  int embeddingsSize,
                  int outEmbeddingsSize,
                  std::vector<int> bufferContext,
                  std::vector<int> stackContext,
                  LSTMImpl::LSTMOptions options,
                  int unknownValueThreshold);

  // Forward pass: maps an input tensor to an output tensor.
  torch::Tensor forward(torch::Tensor input);

  // Size of this submodule's output (override of a base-class virtual,
  // presumably declared by Submodule).
  std::size_t getOutputSize() override;

  // Size of this submodule's input (override of a base-class virtual,
  // presumably declared by Submodule).
  std::size_t getInputSize() override;

  // Fills `context` using `dict` and `config`; does not modify this object
  // (const member function).
  void addToContext(std::vector<std::vector<long>> & context,
                    Dict & dict,
                    const Config & config) const override;
};
// Declares the `ContextLSTM` module-holder type wrapping ContextLSTMImpl
// (libtorch TORCH_MODULE macro).
TORCH_MODULE(ContextLSTM);
#endif