diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index acc5ad557f4ea0d4b8ccb878cf0ba77aecea47ed..caf46e7964ed56535653496c55b95e4d5768e252 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -8,6 +8,8 @@ float NeuralNetworkImpl::entropy(torch::Tensor probabilities) util::myThrow("Invalid probabilities tensor"); probabilities = probabilities.unsqueeze(0); - return - torch::tensordot(probabilities, torch::log(torch::transpose(probabilities, 0, 1)), {0,1}, {0,1}).item<float>(); + auto logProbs = torch::clamp(torch::log(torch::transpose(probabilities, 0, 1)), -10.0, 10.0); + logProbs.index({torch::isnan(logProbs)}) = 0.0; + return - torch::tensordot(probabilities, logProbs, {0,1}, {0,1}).item<float>(); }