diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index 0a69bb34d55ad46430575687232107e557c6ce7e..785c8d9d1c7ecca0342405377a73c055827562ad 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -8,7 +8,7 @@ float NeuralNetworkImpl::entropy(torch::Tensor probabilities) util::myThrow("Invalid probabilities tensor"); probabilities = torch::clamp(probabilities.unsqueeze(0), 0.00000000001, 1.0); - float entropy = torch::sum(probabilities * torch::log(probabilities)).item<float>(); + float entropy = -torch::sum(probabilities * torch::log(probabilities)).item<float>(); return entropy; }