diff --git a/torch_modules/src/WordEmbeddings.cpp b/torch_modules/src/WordEmbeddings.cpp
index 20f7721e1f0c0359667610aa768b099a63c3fe38..c931d6d43b327b661a5a53c76d588d5af997b93e 100644
--- a/torch_modules/src/WordEmbeddings.cpp
+++ b/torch_modules/src/WordEmbeddings.cpp
@@ -6,7 +6,10 @@ float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
 
 WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
 {
-  embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq)));
+  if (maxNorm == std::numeric_limits<float>::max())
+    embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq)));
+  else
+    embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq)));
 }
 
 torch::nn::Embedding WordEmbeddingsImpl::get()