From bbe7862dd74e07c8acff55fd2a17a1afec9c6066 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 25 Feb 2020 11:31:56 +0100
Subject: [PATCH] squeezing the prediction of the Decoder because LSTM predict
 2D tensor

---
 decoder/src/Decoder.cpp | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index aa0b0c2..291da6b 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -24,15 +24,19 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
     machine.getDict(config.getState()).setState(dictState);
 
     auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong);
-    auto prediction = machine.getClassifier()->getNN()(neuralInput);
+    auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
 
     int chosenTransition = -1;
+    float bestScore = std::numeric_limits<float>::min();
 
     try
     {
       for (unsigned int i = 0; i < prediction.size(0); i++)
-        if ((chosenTransition == -1 or prediction[i].item<float>() > prediction[chosenTransition].item<float>()) and machine.getTransitionSet().getTransition(i)->appliable(config))
+      {
+        float score = prediction[i].item<float>();
+        if (score > bestScore and machine.getTransitionSet().getTransition(i)->appliable(config))
           chosenTransition = i;
+      }
     } catch(std::exception & e) {util::myThrow(e.what());}
 
     if (chosenTransition == -1)
-- 
GitLab