Skip to content
Snippets Groups Projects
Select Git revision
  • ed05ee4ac47f827f31bd8de5b760c772106be90f
  • master default protected
  • loss
  • producer
4 results

DepthLayerTreeEmbeddingModule.hpp

Blame
  • DepthLayerTreeEmbeddingModule.hpp 984 B
    #ifndef DEPTHLAYERTREEEMBEDDING__H
    #define DEPTHLAYERTREEEMBEDDING__H
    
    #include <torch/torch.h>
    #include "Submodule.hpp"
    #include "MyModule.hpp"
    #include "LSTM.hpp"
    #include "GRU.hpp"
    #include "Concat.hpp"
    
    /// Implementation class of a Submodule that owns a word-embedding table and a
    /// list of MyModule instances (depthModules).
    ///
    /// NOTE(review): going by the names only, this appears to embed tree context
    /// around focused buffer/stack elements, one module per depth layer — confirm
    /// against the .cpp, which is not visible from this header.
    class DepthLayerTreeEmbeddingModuleImpl : public Submodule
    {
      private :
    
      // Presumably the maximum number of elements considered at each depth — TODO confirm.
      std::vector<int> maxElemPerDepth;
      // Presumably the names of the input columns this module reads — TODO confirm.
      std::vector<std::string> columns;
      // Presumably indexes (relative to the buffer) of the focused elements — TODO confirm.
      std::vector<int> focusedBuffer;
      // Presumably indexes (relative to the stack) of the focused elements — TODO confirm.
      std::vector<int> focusedStack;
      // Starts as an empty holder (nullptr); the actual Embedding has to be created
      // later — presumably in registerEmbeddings(), verify in the .cpp.
      torch::nn::Embedding wordEmbeddings{nullptr};
      // Shared-ownership list of sub-networks; presumably one per depth layer — TODO confirm.
      std::vector<std::shared_ptr<MyModule>> depthModules;
      // No in-class initializer: the value is indeterminate unless the constructor
      // (defined elsewhere) assigns it.
      int inSize;
    
      public :
    
      /// @param name        Module name (taken by value).
      /// @param definition  Textual definition of the module; its format is parsed
      ///                    in the .cpp and cannot be determined from this header.
      DepthLayerTreeEmbeddingModuleImpl(std::string name, const std::string & definition);
      /// Forward pass: maps an input tensor to an output tensor.
      /// Not marked `override`, unlike the member functions below.
      /// NOTE(review): expected tensor shape/dtype is not visible here — confirm.
      torch::Tensor forward(torch::Tensor input);
      /// Overrides a virtual from the Submodule hierarchy; presumably the size of the
      /// tensor produced by forward() — TODO confirm which dimension.
      std::size_t getOutputSize() override;
      /// Overrides a virtual from the Submodule hierarchy; presumably the number of
      /// context entries this module consumes — TODO confirm.
      std::size_t getInputSize() override;
      /// Overrides a virtual from the Submodule hierarchy. `context` is taken by
      /// non-const reference, so it may be modified; `config` is read-only.
      void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
      /// Overrides a virtual from the Submodule hierarchy.
      /// @param pretrained  Filesystem path; presumably to pretrained embeddings — TODO confirm.
      void registerEmbeddings(std::filesystem::path pretrained) override;
    };
    // libtorch macro: declares the holder class DepthLayerTreeEmbeddingModule that
    // wraps a shared pointer to DepthLayerTreeEmbeddingModuleImpl.
    TORCH_MODULE(DepthLayerTreeEmbeddingModule);
    
    #endif