Skip to content
Snippets Groups Projects
Commit 7e1a6789 authored by Franck Dary's avatar Franck Dary
Browse files

When constructing WordEmbedding, only use max_norm when necessary

parent 8dfbc696
No related branches found
No related tags found
No related merge requests found
...@@ -6,7 +6,10 @@ float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max(); ...@@ -6,7 +6,10 @@ float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim) WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
{ {
embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq))); if (maxNorm == std::numeric_limits<float>::max())
embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq)));
else
embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq)));
} }
torch::nn::Embedding WordEmbeddingsImpl::get() torch::nn::Embedding WordEmbeddingsImpl::get()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment