// DistanceModule.hpp (author: Franck Dary)
#ifndef DISTANCEMODULE__H
#define DISTANCEMODULE__H

#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include <torch/torch.h>

#include "Submodule.hpp"
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
#include "Concat.hpp"
#include "WordEmbeddings.hpp"

class DistanceModuleImpl : public Submodule
{
  private :

16
  WordEmbeddings wordEmbeddings{nullptr};
Franck Dary's avatar
Franck Dary committed
17
18
19
20
21
22
23
24
25
26
27
28
29
  std::shared_ptr<MyModule> myModule{nullptr};
  std::vector<int> fromBuffer, fromStack;
  std::vector<int> toBuffer, toStack;
  int threshold;
  int inSize;

  public :

  DistanceModuleImpl(std::string name, const std::string & definition);
  torch::Tensor forward(torch::Tensor input);
  std::size_t getOutputSize() override;
  std::size_t getInputSize() override;
  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
30
  void registerEmbeddings() override;
Franck Dary's avatar
Franck Dary committed
31
32
33
34
35
};
TORCH_MODULE(DistanceModule);

#endif