Skip to content
Snippets Groups Projects
DepthLayerTreeEmbeddingModule.hpp 984 B
// Include guard. The macro name deliberately avoids "__": identifiers that
// contain a double underscore are reserved for the implementation in C++.
#ifndef DEPTHLAYERTREEEMBEDDINGMODULE_HPP
#define DEPTHLAYERTREEEMBEDDINGMODULE_HPP

#include <torch/torch.h>
#include "Submodule.hpp"
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
#include "Concat.hpp"

/// Implementation half of the `DepthLayerTreeEmbeddingModule` libtorch module.
/// Derives from Submodule; only the interface is declared in this header.
class DepthLayerTreeEmbeddingModuleImpl : public Submodule
{
  private :

  /// NOTE(review): name suggests a per-depth limit on how many elements are
  /// embedded at each tree depth — confirm against the implementation.
  std::vector<int> maxElemPerDepth;
  /// NOTE(review): presumably the names of the columns this module reads — confirm.
  std::vector<std::string> columns;
  /// NOTE(review): presumably the buffer positions this module looks at — confirm.
  std::vector<int> focusedBuffer;
  /// NOTE(review): presumably the stack positions this module looks at — confirm.
  std::vector<int> focusedStack;
  /// Embedding module holder; starts empty (null) and must be assigned
  /// before it is used.
  torch::nn::Embedding wordEmbeddings{nullptr};
  /// Owned sub-modules. NOTE(review): presumably one per tree depth — confirm.
  std::vector<std::shared_ptr<MyModule>> depthModules;
  /// NOTE(review): presumably the value reported by getInputSize().
  /// Value-initialized so it is never read indeterminate if a constructor
  /// path leaves it unset.
  int inSize{0};

  public :

  /// Constructs the module.
  /// @param name        name of this module (taken by value).
  /// @param definition  textual definition of the module — NOTE(review): the
  ///                    expected format is defined by the implementation.
  DepthLayerTreeEmbeddingModuleImpl(std::string name, const std::string & definition);

  /// Runs the module on `input` and returns the resulting tensor.
  torch::Tensor forward(torch::Tensor input);

  /// @return the output size of this module (overrides Submodule).
  std::size_t getOutputSize() override;

  /// @return the input size of this module (overrides Submodule).
  std::size_t getInputSize() override;

  /// Overrides Submodule::addToContext.
  /// @param context  in-out container; taken by non-const reference, so it
  ///                 may be modified in place.
  /// @param config   read-only configuration.
  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;

  /// Overrides Submodule::registerEmbeddings.
  /// @param pretrained  filesystem path — NOTE(review): presumably the location
  ///                    of pretrained embeddings; confirm how an empty path is handled.
  void registerEmbeddings(std::filesystem::path pretrained) override;
};
// Declares the `DepthLayerTreeEmbeddingModule` holder type for
// DepthLayerTreeEmbeddingModuleImpl using libtorch's TORCH_MODULE macro
// (libtorch's module-holder convention: a shared-ownership wrapper around the Impl).
TORCH_MODULE(DepthLayerTreeEmbeddingModule);

#endif // include guard