WordEmbeddings.cpp 1.02 KB
Newer Older
1
2
3
#include "WordEmbeddings.hpp"

#include <limits>

/// Class-wide scale_grad_by_freq flag, read when a WordEmbeddingsImpl is constructed. Off by default.
bool WordEmbeddingsImpl::scaleGradByFreq{false};
4
/// Class-wide flag exposed through setCanTrainPretrained()/getCanTrainPretrained(); false by default.
/// NOTE(review): not read elsewhere in this file — presumably consulted by code loading pretrained vectors; confirm.
bool WordEmbeddingsImpl::canTrainPretrained{false};
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
/// Class-wide embedding max-norm, read when a WordEmbeddingsImpl is constructed.
/// The largest finite float is used as the default, i.e. effectively "no norm cap".
float WordEmbeddingsImpl::maxNorm{std::numeric_limits<float>::max()};

/// Builds and registers (under the name "embeddings") a torch::nn::Embedding
/// table with `vocab` rows of size `dim`.
///
/// The class-wide settings scaleGradByFreq and maxNorm are read here, so
/// changing them later through the static setters does not affect modules
/// that already exist.
///
/// @param vocab number of rows (distinct indices) in the embedding table.
/// @param dim   size of each embedding vector.
WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
{
  torch::nn::EmbeddingOptions options(vocab, dim);
  options.scale_grad_by_freq(scaleGradByFreq);

  // maxNorm defaults to the largest finite float, meaning "no cap". In libtorch,
  // merely setting max_norm makes every forward() renormalize the looked-up rows
  // of the weight in place (and make the indices contiguous), even when the limit
  // can never be reached. Only enable it when a real limit has been configured.
  // Infinity and NaN also fail this test; for them renorm could never change
  // anything either.
  if (maxNorm < std::numeric_limits<float>::max())
    options.max_norm(maxNorm);

  embeddings = register_module("embeddings", torch::nn::Embedding(options));
}

/// Returns the wrapped torch::nn::Embedding holder that was registered in the constructor.
torch::nn::Embedding WordEmbeddingsImpl::get()
{
  return this->embeddings;
}

void WordEmbeddingsImpl::setScaleGradByFreq(bool scaleGradByFreq)
{
  WordEmbeddingsImpl::scaleGradByFreq = scaleGradByFreq;
}

void WordEmbeddingsImpl::setMaxNorm(float maxNorm)
{
  WordEmbeddingsImpl::maxNorm = maxNorm;
}

27
28
29
30
31
void WordEmbeddingsImpl::setCanTrainPretrained(bool value)
{
  WordEmbeddingsImpl::canTrainPretrained = value;
}

32
33
34
35
36
/// Looks up `input` in the wrapped embedding table and returns the resulting vectors.
/// @param input index tensor to look up (presumably integral token ids — confirm with callers).
/// @return the embedding rows selected by `input`.
torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
{
  torch::Tensor vectors = embeddings(input);
  return vectors;
}

37
38
39
40
41
/// Returns the class-wide canTrainPretrained flag last set by setCanTrainPretrained() (false by default).
bool WordEmbeddingsImpl::getCanTrainPretrained()
{
  return WordEmbeddingsImpl::canTrainPretrained;
}