From 66b09b73bfc89ec129599c40bc66bdcab12fde0c Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 19 Apr 2022 18:17:41 +0200 Subject: [PATCH] Ignoring values close to zero in entropy computation --- torch_modules/src/NeuralNetwork.cpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index fe3727d..28564b6 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -7,8 +7,15 @@ float NeuralNetworkImpl::entropy(torch::Tensor probabilities) if (probabilities.dim() != 1) util::myThrow("Invalid probabilities tensor"); - probabilities = torch::clamp(probabilities.unsqueeze(0), 0.00000000001, 1.0); - float entropy = -torch::sum(probabilities * torch::log(probabilities)).item<float>(); + float entropy = 0.0; + for (unsigned int i = 0; i < probabilities.size(0); i++) + { + if (probabilities[i].item<float>() > 0.01) + entropy -= (probabilities[i] * torch::log(probabilities[i])).item<float>(); + } + + if (entropy < 0.01) + entropy = 0.0; return entropy; } -- GitLab