From 7e1a67891600a52ec03cea7e99eaf687dbae4d32 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 24 Mar 2021 10:01:10 +0100
Subject: [PATCH] When constructing WordEmbeddings, only use max_norm when
 necessary

---
 torch_modules/src/WordEmbeddings.cpp | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

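Note (commentary, not part of the commit message): torch::nn::Embedding with
max_norm set runs a renormalization step on every forward pass, clamping the
looked-up rows in place whenever their norm exceeds the bound, so passing
std::numeric_limits<float>::max() as a "no limit" sentinel still pays for that
check on every lookup. Omitting the option when no real bound is configured
skips that path entirely. A minimal sketch of an equivalent way to express the
guard, building the options object once; the local name "opts" is illustrative
and not taken from the repository:

  // Sketch only: same behaviour as the patched code, without duplicating
  // the register_module call. "opts" is a hypothetical local variable.
  auto opts = torch::nn::EmbeddingOptions(vocab, dim)
                  .scale_grad_by_freq(scaleGradByFreq);
  if (maxNorm != std::numeric_limits<float>::max())
    opts.max_norm(maxNorm);   // only set a norm bound when one was configured
  embeddings = register_module("embeddings", torch::nn::Embedding(opts));
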
diff --git a/torch_modules/src/WordEmbeddings.cpp b/torch_modules/src/WordEmbeddings.cpp
index 20f7721..c931d6d 100644
--- a/torch_modules/src/WordEmbeddings.cpp
+++ b/torch_modules/src/WordEmbeddings.cpp
@@ -6,7 +6,10 @@ float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
 
 WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
 {
-  embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq)));
+  if (maxNorm == std::numeric_limits<float>::max())
+    embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq)));
+  else
+    embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq)));
 }
 
 torch::nn::Embedding WordEmbeddingsImpl::get()
-- 
GitLab