#ifndef CONTEXTLSTM__H #define CONTEXTLSTM__H #include <torch/torch.h> #include "Submodule.hpp" #include "LSTM.hpp" class ContextLSTMImpl : public torch::nn::Module, public Submodule { private : LSTM lstm{nullptr}; std::vector<std::string> columns; std::vector<int> bufferContext; std::vector<int> stackContext; int unknownValueThreshold; std::vector<std::string> unknownValueColumns{"FORM", "LEMMA"}; public : ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; }; TORCH_MODULE(ContextLSTM); #endif