// Submodule.hpp -- declares the abstract Submodule base class.
#ifndef SUBMODULE__H
#define SUBMODULE__H

#include <cstddef>
#include <filesystem>
#include <functional>
#include <string>
#include <vector>

#include <torch/torch.h>

#include "Config.hpp"
#include "DictHolder.hpp"
#include "StateHolder.hpp"

/// Abstract base class that combines a torch::nn::Module with the DictHolder
/// and StateHolder interfaces.
///
/// Concrete submodules must implement every pure virtual member declared
/// below. The non-virtual members are only declared here; their definitions
/// live outside this header.
class Submodule : public torch::nn::Module, public DictHolder, public StateHolder
{
  protected :

  /// Index of this submodule's first input; initialised to 0.
  /// NOTE(review): presumably the offset of this submodule's features inside
  /// the shared context row / input tensor -- confirm against the definition.
  std::size_t firstInputIndex{0};

  public :

  /// Setter for the firstInputIndex member (defined out of line).
  /// @param firstInputIndex new value; note that the parameter shadows the
  ///        member of the same name.
  void setFirstInputIndex(std::size_t firstInputIndex);

  /// Loads pretrained embeddings into an embedding layer (defined out of line).
  /// NOTE(review): judging by the name, `path` points to a word2vec-style file
  /// and `prefix` is applied to the entries read from it -- confirm the exact
  /// file format and the role of `prefix` in the implementation.
  /// @param embeddings embedding layer handle to fill.
  /// @param path       location of the pretrained embeddings.
  /// @param prefix     string prefix used while loading.
  void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix);

  /// @return the output size of this submodule.
  /// NOTE(review): unit (features per example vs. total elements) is decided
  /// by the implementations -- confirm there.
  virtual std::size_t getOutputSize() = 0;

  /// @return the input size of this submodule.
  /// NOTE(review): presumably the number of context columns consumed starting
  /// at firstInputIndex -- confirm.
  virtual std::size_t getInputSize() = 0;

  /// Contributes this submodule's part to `context` for the given `config`.
  /// @param context rows of long values, taken by mutable reference so that
  ///        implementations can modify it.
  /// @param config  read-only configuration the values are derived from.
  virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;

  /// Forward pass of the submodule.
  /// @param input input tensor.
  /// @return the tensor produced by this submodule.
  virtual torch::Tensor forward(torch::Tensor input) = 0;

  /// Hook in which implementations register their embedding layers.
  /// NOTE(review): presumably through torch's register_module -- confirm.
  virtual void registerEmbeddings() = 0;

  /// Builds a string-to-string function from a textual description
  /// (defined out of line).
  /// NOTE(review): the plural name suggests `functionNames` may hold several
  /// function names to be combined -- confirm the expected syntax.
  /// @param functionNames description of the function(s) to build.
  /// @return a callable mapping an input string to a transformed string.
  std::function<std::string(const std::string &)> getFunction(const std::string functionNames);
};

#endif