#ifndef DEPTHLAYERTREEEMBEDDING__H #define DEPTHLAYERTREEEMBEDDING__H #include <torch/torch.h> #include "Submodule.hpp" #include "LSTM.hpp" class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule { private : std::vector<int> maxElemPerDepth; std::vector<std::string> columns; std::vector<int> focusedBuffer; std::vector<int> focusedStack; std::vector<LSTM> depthLstm; public : DepthLayerTreeEmbeddingImpl(std::vector<int> maxElemPerDepth, int embeddingsSize, int outEmbeddingsSize, std::vector<std::string> columns, std::vector<int> focusedBuffer, std::vector<int> focusedStack, LSTMImpl::LSTMOptions options); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; }; TORCH_MODULE(DepthLayerTreeEmbedding); #endif