From 66b09b73bfc89ec129599c40bc66bdcab12fde0c Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 19 Apr 2022 18:17:41 +0200
Subject: [PATCH] Ignoring values close to zero in entropy computation

---
 torch_modules/src/NeuralNetwork.cpp | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index fe3727d..28564b6 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -7,8 +7,15 @@ float NeuralNetworkImpl::entropy(torch::Tensor probabilities)
   if (probabilities.dim() != 1)
     util::myThrow("Invalid probabilities tensor");
 
-  probabilities = torch::clamp(probabilities.unsqueeze(0), 0.00000000001, 1.0);
-  float entropy = -torch::sum(probabilities * torch::log(probabilities)).item<float>();
+  float entropy = 0.0;
+  for (unsigned int i = 0; i < probabilities.size(0); i++)
+  {
+    if (probabilities[i].item<float>() > 0.01)
+      entropy -= (probabilities[i] * torch::log(probabilities[i])).item<float>();
+  }
+
+  if (entropy < 0.01)
+    entropy = 0.0;
 
   return entropy;
 }
-- 
GitLab