From 5888ee1f1554b90f94e8e6db01ae81f4c8f97ba8 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 23 Oct 2020 17:32:27 +0200
Subject: [PATCH] Using torch functions to compute entropy

---
 decoder/src/Decoder.cpp             |  2 +-
 torch_modules/src/NeuralNetwork.cpp | 11 ++++-------
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 22f9689..7739b1d 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -43,7 +43,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh
 
   if (machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1)
   {
-    machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0")->apply(baseConfig, 0.0);
+    machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0")->apply(baseConfig);
     if (debug)
     {
       fmt::print(stderr, "Forcing EOS transition\n");
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index dbab2eb..acc5ad5 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -4,13 +4,10 @@ torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCU
 
 float NeuralNetworkImpl::entropy(torch::Tensor probabilities)
 {
-  float res = 0.0;
-  for (unsigned int i = 0; i < probabilities.size(0); i++)
-  {
-    float val = probabilities[i].item<float>();
-    res -= val * log(val);
-  }
+  if (probabilities.dim() != 1)
+    util::myThrow("Invalid probabilities tensor");
 
-  return res;
+  probabilities = probabilities.unsqueeze(0);
+  return - torch::tensordot(probabilities, torch::log(torch::transpose(probabilities, 0, 1)), {0,1}, {0,1}).item<float>();
 }
 
-- 
GitLab