#include "WordEmbeddings.hpp"

#include <limits>

// Construction-time configuration shared by all WordEmbeddingsImpl instances.
// These values are read only in the constructor, so changing them through the
// static setters affects embeddings constructed afterwards, not existing ones.
bool WordEmbeddingsImpl::scaleGradByFreq = false;
// The maximum float value is used as a sentinel meaning "no norm limit".
float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();

/// Creates the embedding table and registers it as the "embeddings" submodule.
/// @param vocab number of rows (entries) in the table
/// @param dim   size of each embedding vector
/// Uses the current values of the static scaleGradByFreq and maxNorm settings.
WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
{
  torch::nn::EmbeddingOptions options(vocab, dim);
  options.scale_grad_by_freq(scaleGradByFreq);
  // When max_norm is set, libtorch renormalises the looked-up rows on every
  // forward call. Skip that work while maxNorm is still the "unlimited"
  // sentinel; only an explicitly configured limit is forwarded.
  if (maxNorm < std::numeric_limits<float>::max())
    options.max_norm(maxNorm);

  embeddings = register_module("embeddings", torch::nn::Embedding(options));
}

/// Returns the underlying embedding module. The module holder is shared, not
/// deep-copied.
torch::nn::Embedding WordEmbeddingsImpl::get()
{
  return embeddings;
}

/// Sets whether embeddings constructed from now on scale gradients by the
/// inverse frequency of words in the mini-batch.
void WordEmbeddingsImpl::setScaleGradByFreq(bool value)
{
  WordEmbeddingsImpl::scaleGradByFreq = value;
}

/// Sets the maximum row norm for embeddings constructed from now on.
/// Passing std::numeric_limits<float>::max() means "no limit".
void WordEmbeddingsImpl::setMaxNorm(float value)
{
  WordEmbeddingsImpl::maxNorm = value;
}

/// Looks up the embedding vectors for the given index tensor.
torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
{
  return embeddings(input);
}