diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index caf46e7964ed56535653496c55b95e4d5768e252..0a69bb34d55ad46430575687232107e557c6ce7e 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -7,9 +7,9 @@ float NeuralNetworkImpl::entropy(torch::Tensor probabilities) if (probabilities.dim() != 1) util::myThrow("Invalid probabilities tensor"); - probabilities = probabilities.unsqueeze(0); - auto logProbs = torch::clamp(torch::log(torch::transpose(probabilities, 0, 1)), -10.0, 10.0); - logProbs.index({torch::isnan(logProbs)}) = 0.0; - return - torch::tensordot(probabilities, logProbs, {0,1}, {0,1}).item<float>(); + probabilities = torch::clamp(probabilities.unsqueeze(0), 0.00000000001, 1.0); + float entropy = torch::sum(probabilities * torch::log(probabilities)).item<float>(); + + return entropy; }