#ifndef CONTEXTMODULE__H
#define CONTEXTMODULE__H
#include <torch/torch.h>
#include <optional>
#include "Submodule.hpp"
#include "MyModule.hpp"
#include "GRU.hpp"
#include "LSTM.hpp"
#include "Concat.hpp"
#include "Transformer.hpp"
#include "WordEmbeddings.hpp"
class ContextModuleImpl : public Submodule
Franck Dary's avatar
Franck Dary committed
15
16
17
{
  private :

18
  WordEmbeddings wordEmbeddings{nullptr};
19
  std::shared_ptr<MyModule> myModule{nullptr};
Franck Dary's avatar
Franck Dary committed
20
  std::vector<std::string> columns;
21
  std::vector<std::function<std::string(const std::string &)>> functions;
22
  std::vector<std::tuple<Config::Object, int, std::optional<int>>> targets;
23
  int inSize;
24
  std::filesystem::path path;
25
  std::filesystem::path w2vFiles;
Franck Dary's avatar
Franck Dary committed
26
27
28

  public :

29
  ContextModuleImpl(std::string name, const std::string & definition, std::filesystem::path path);
Franck Dary's avatar
Franck Dary committed
30
31
32
  torch::Tensor forward(torch::Tensor input);
  std::size_t getOutputSize() override;
  std::size_t getInputSize() override;
Franck Dary's avatar
Franck Dary committed
33
  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
34
  void registerEmbeddings() override;
Franck Dary's avatar
Franck Dary committed
35
};
36
TORCH_MODULE(ContextModule);
#endif