WordEmbeddings.hpp 547 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#ifndef WORDEMBEDDINGS__H
#define WORDEMBEDDINGS__H

#include "torch/torch.h"

class WordEmbeddingsImpl : public torch::nn::Module
{
  private :

  static bool scaleGradByFreq;
  static float maxNorm;

  private :
  
  torch::nn::Embedding embeddings{nullptr};

  public :

  static void setScaleGradByFreq(bool scaleGradByFreq);
  static void setMaxNorm(float maxNorm);

  WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
  torch::nn::Embedding get();
  torch::Tensor forward(torch::Tensor input);
};
TORCH_MODULE(WordEmbeddings);

#endif