// ContextualModule.hpp
// Author: Franck Dary
#ifndef CONTEXTUALMODULE__H
#define CONTEXTUALMODULE__H

#include <cstddef>
#include <filesystem>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include <torch/torch.h>

#include "Submodule.hpp"
#include "MyModule.hpp"
#include "GRU.hpp"
#include "LSTM.hpp"
#include "Concat.hpp"

/// Submodule that embeds values extracted from a window of configuration
/// elements and passes them through an inner module (`myModule`).
///
/// NOTE(review): this summary is inferred from member names and types
/// (wordEmbeddings, window, targets, myModule); confirm it against the
/// definitions in ContextualModule.cpp.
class ContextualModuleImpl : public Submodule
{
  private :

  /// Embedding table for the extracted values; presumably created and
  /// registered in registerEmbeddings() -- TODO confirm.
  torch::nn::Embedding wordEmbeddings{nullptr};
  /// Inner module applied to the embedded context. The GRU/LSTM/Concat
  /// includes suggest it is one of those -- TODO confirm in the .cpp.
  std::shared_ptr<MyModule> myModule{nullptr};
  /// Column names; presumably the columns read from the Config -- TODO confirm.
  std::vector<std::string> columns;
  /// String -> string transforms; presumably one per entry of `columns`
  /// -- TODO confirm.
  std::vector<std::function<std::string(const std::string &)>> functions;
  /// (object, index, optional index) triples; their exact meaning is not
  /// visible in this header.
  std::vector<std::tuple<Config::Object, int, std::optional<int>>> targets;
  /// Context window bounds; presumably (left, right) offsets -- TODO confirm.
  std::pair<int,int> window;
  /// Input/output sizes, presumably computed by the constructor and reported
  /// by getInputSize()/getOutputSize(). Default-initialized so they never
  /// hold indeterminate values.
  int inSize{0};
  int outSize{0};
  /// Presumably the constructor's `path` argument -- TODO confirm.
  std::filesystem::path path;
  /// Pretrained word-vector file(s). Note: a single path despite the plural
  /// name -- TODO confirm how it is populated/parsed.
  std::filesystem::path w2vFiles;

  public :

  /// Builds the module.
  /// @param name        Module name.
  /// @param definition  Textual module definition; its format is parsed in
  ///                    the .cpp (not visible here).
  /// @param path        Filesystem path; presumably the model directory
  ///                    -- TODO confirm.
  ContextualModuleImpl(std::string name, const std::string & definition, std::filesystem::path path);
  /// Runs the module on `input`; the expected tensor shape is not visible here.
  torch::Tensor forward(torch::Tensor input);
  /// Size of the tensor produced by forward() (Submodule interface).
  std::size_t getOutputSize() override;
  /// Number of input features this module consumes (Submodule interface).
  std::size_t getInputSize() override;
  /// Adds this module's features for `config` to `context`; presumably
  /// appends to each row -- TODO confirm.
  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
  /// Creates/registers the embedding layer(s) (Submodule interface).
  void registerEmbeddings() override;
};
// Declares `ContextualModule`, libtorch's module-holder wrapper around
// ContextualModuleImpl (generated by the TORCH_MODULE macro).
TORCH_MODULE(ContextualModule);

#endif