diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index fe3727dc18827225b17bbf6db9133b7dd67ce54e..28564b627dec2f8d1fb7ef80443ce2271285a257 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -7,8 +7,15 @@ float NeuralNetworkImpl::entropy(torch::Tensor probabilities) if (probabilities.dim() != 1) util::myThrow("Invalid probabilities tensor"); - probabilities = torch::clamp(probabilities.unsqueeze(0), 0.00000000001, 1.0); - float entropy = -torch::sum(probabilities * torch::log(probabilities)).item<float>(); + float entropy = 0.0; + for (unsigned int i = 0; i < probabilities.size(0); i++) + { + if (probabilities[i].item<float>() > 0.01) + entropy -= (probabilities[i] * torch::log(probabilities[i])).item<float>(); + } + + if (entropy < 0.01) + entropy = 0.0; return entropy; }