From 7be95e0f04ee15d5d56f7978d9a6d09692f98135 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 23 Oct 2020 22:12:17 +0200
Subject: [PATCH] Remove NaN and clamping of entropy

---
 torch_modules/src/NeuralNetwork.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index acc5ad5..caf46e7 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -8,6 +8,8 @@ float NeuralNetworkImpl::entropy(torch::Tensor probabilities)
     util::myThrow("Invalid probabilities tensor");
 
   probabilities = probabilities.unsqueeze(0);
-  return - torch::tensordot(probabilities, torch::log(torch::transpose(probabilities, 0, 1)), {0,1}, {0,1}).item<float>();
+  auto logProbs = torch::clamp(torch::log(torch::transpose(probabilities, 0, 1)), -10.0, 10.0);
+  logProbs.index({torch::isnan(logProbs)}) = 0.0;
+  return - torch::tensordot(probabilities, logProbs, {0,1}, {0,1}).item<float>();
 }
 
-- 
GitLab