#ifndef DEPTHLAYERTREEEMBEDDING__H #define DEPTHLAYERTREEEMBEDDING__H #include <torch/torch.h> #include "Submodule.hpp" #include "MyModule.hpp" #include "LSTM.hpp" #include "GRU.hpp" #include "Concat.hpp" class DepthLayerTreeEmbeddingModuleImpl : public Submodule { private : std::vector<int> maxElemPerDepth; std::vector<std::string> columns; std::vector<int> focusedBuffer; std::vector<int> focusedStack; torch::nn::Embedding wordEmbeddings{nullptr}; std::vector<std::shared_ptr<MyModule>> depthModules; int inSize; public : DepthLayerTreeEmbeddingModuleImpl(std::string name, const std::string & definition); torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; void registerEmbeddings(std::filesystem::path pretrained) override; }; TORCH_MODULE(DepthLayerTreeEmbeddingModule); #endif