From 2a7eb155f036b0ff0e632ca5e3c34b6b495faafc Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 17 Nov 2020 14:16:37 +0100
Subject: [PATCH] Corrected error in computing of entropy (changed dot product
 to hadamard product)

---
 torch_modules/src/NeuralNetwork.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index caf46e7..0a69bb3 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -7,9 +7,9 @@ float NeuralNetworkImpl::entropy(torch::Tensor probabilities)
   if (probabilities.dim() != 1)
     util::myThrow("Invalid probabilities tensor");
 
-  probabilities = probabilities.unsqueeze(0);
-  auto logProbs = torch::clamp(torch::log(torch::transpose(probabilities, 0, 1)), -10.0, 10.0);
-  logProbs.index({torch::isnan(logProbs)}) = 0.0;
-  return - torch::tensordot(probabilities, logProbs, {0,1}, {0,1}).item<float>();
+  probabilities = torch::clamp(probabilities.unsqueeze(0), 0.00000000001, 1.0);
+  float entropy = torch::sum(probabilities * torch::log(probabilities)).item<float>();
+
+  return entropy;
 }
 
-- 
GitLab