WordEmbeddings.cpp 790 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#include "WordEmbeddings.hpp"

#include <limits>

// Class-wide configuration read by every WordEmbeddingsImpl constructor.
// Changing these (via the setters) only affects modules constructed afterwards.
// Gradient scaling by word frequency is off unless explicitly enabled.
bool WordEmbeddingsImpl::scaleGradByFreq{false};
// Largest finite float acts as "no norm limit" sentinel.
float WordEmbeddingsImpl::maxNorm{std::numeric_limits<float>::max()};

/// @brief Builds the embedding table and registers it as submodule "embeddings".
/// @param vocab Number of rows in the table (vocabulary size).
/// @param dim   Size of each embedding vector.
///
/// Reads the class-wide settings scaleGradByFreq and maxNorm at construction
/// time; later calls to the setters do not affect an already-built module.
WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
{
  auto options = torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq);

  // Only request renormalization when an actual limit was configured.
  // When max_norm is set, libtorch renormalizes the looked-up rows in-place
  // on every forward call; with the FLT_MAX "unlimited" sentinel that pass
  // never rescales anything, so skipping it keeps results identical while
  // avoiding the per-forward cost. (NaN also falls through here, matching
  // the original no-op renorm behavior for a NaN limit.)
  if (maxNorm < std::numeric_limits<float>::max())
    options.max_norm(maxNorm);

  embeddings = register_module("embeddings", torch::nn::Embedding(options));
}

/// @brief Returns the underlying Embedding module holder.
/// @return A holder sharing the same module instance as this object
///         (holders are handle types, so no weights are copied).
torch::nn::Embedding WordEmbeddingsImpl::get()
{
  return this->embeddings;
}

void WordEmbeddingsImpl::setScaleGradByFreq(bool scaleGradByFreq)
{
  WordEmbeddingsImpl::scaleGradByFreq = scaleGradByFreq;
}

void WordEmbeddingsImpl::setMaxNorm(float maxNorm)
{
  WordEmbeddingsImpl::maxNorm = maxNorm;
}

/// @brief Looks up embeddings for the given indices.
/// @param input Index tensor passed straight to the Embedding submodule
///        (presumably integer word ids — TODO confirm with callers).
/// @return The tensor produced by the Embedding submodule's forward.
torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
{
  return embeddings->forward(input);
}