FocusedColumnModule.hpp 1.04 KB
Newer Older
1
2
#ifndef FOCUSEDCOLUMNMODULE__H
#define FOCUSEDCOLUMNMODULE__H
Franck Dary's avatar
Franck Dary committed
3
4
5

#include <torch/torch.h>
#include "Submodule.hpp"
6
#include "MyModule.hpp"
Franck Dary's avatar
Franck Dary committed
7
#include "LSTM.hpp"
8
#include "GRU.hpp"
Franck Dary's avatar
Franck Dary committed
9
#include "Concat.hpp"
10
#include "WordEmbeddings.hpp"
Franck Dary's avatar
Franck Dary committed
11

Franck Dary's avatar
Franck Dary committed
12
class FocusedColumnModuleImpl : public Submodule
Franck Dary's avatar
Franck Dary committed
13
14
15
{
  private :

16
  WordEmbeddings wordEmbeddings{nullptr};
17
  std::shared_ptr<MyModule> myModule{nullptr};
Franck Dary's avatar
Franck Dary committed
18
19
  std::vector<int> focusedBuffer, focusedStack;
  std::string column;
20
  std::function<std::string(const std::string&)> func{[](const std::string &s){return s;}};
Franck Dary's avatar
Franck Dary committed
21
  int maxNbElements;
22
  int inSize;
23
24
  std::filesystem::path path;
  std::filesystem::path w2vFiles;
Franck Dary's avatar
Franck Dary committed
25
26
27

  public :

28
  FocusedColumnModuleImpl(std::string name, const std::string & definition, std::filesystem::path path);
Franck Dary's avatar
Franck Dary committed
29
30
31
  torch::Tensor forward(torch::Tensor input);
  std::size_t getOutputSize() override;
  std::size_t getInputSize() override;
Franck Dary's avatar
Franck Dary committed
32
  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
33
  void registerEmbeddings() override;
Franck Dary's avatar
Franck Dary committed
34
};
Franck Dary's avatar
Franck Dary committed
35
// Declares FocusedColumnModule, libtorch's module-holder handle around
// FocusedColumnModuleImpl (TORCH_MODULE(Name) wraps the class NameImpl).
TORCH_MODULE(FocusedColumnModule);
Franck Dary's avatar
Franck Dary committed
36
37
38

#endif