DepthLayerTreeEmbeddingModule.hpp 976 Bytes
Newer Older
Franck Dary's avatar
Franck Dary committed
1
2
3
4
#ifndef DEPTHLAYERTREEEMBEDDING__H
#define DEPTHLAYERTREEEMBEDDING__H

#include <torch/torch.h>
5
#include "Submodule.hpp"
6
#include "MyModule.hpp"
Franck Dary's avatar
Franck Dary committed
7
#include "LSTM.hpp"
8
#include "GRU.hpp"
Franck Dary's avatar
Franck Dary committed
9
#include "Concat.hpp"
10
#include "WordEmbeddings.hpp"
Franck Dary's avatar
Franck Dary committed
11

Franck Dary's avatar
Franck Dary committed
12
class DepthLayerTreeEmbeddingModuleImpl : public Submodule
Franck Dary's avatar
Franck Dary committed
13
14
15
{
  private :

16
17
18
19
  std::vector<int> maxElemPerDepth;
  std::vector<std::string> columns;
  std::vector<int> focusedBuffer;
  std::vector<int> focusedStack;
20
  WordEmbeddings wordEmbeddings{nullptr};
21
  std::vector<std::shared_ptr<MyModule>> depthModules;
22
  int inSize;
Franck Dary's avatar
Franck Dary committed
23
24
25

  public :

Franck Dary's avatar
Franck Dary committed
26
  DepthLayerTreeEmbeddingModuleImpl(std::string name, const std::string & definition);
Franck Dary's avatar
Franck Dary committed
27
  torch::Tensor forward(torch::Tensor input);
28
29
  std::size_t getOutputSize() override;
  std::size_t getInputSize() override;
Franck Dary's avatar
Franck Dary committed
30
  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
31
  void registerEmbeddings() override;
Franck Dary's avatar
Franck Dary committed
32
};
Franck Dary's avatar
Franck Dary committed
33
TORCH_MODULE(DepthLayerTreeEmbeddingModule);
Franck Dary's avatar
Franck Dary committed
34
35
36

#endif