diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index aa0b0c2b4acc52510e6f13af019de059774b8373..291da6b5de7de8c8b0ff2b0b17a0007ee491a289 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -24,15 +24,19 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
     machine.getDict(config.getState()).setState(dictState);
 
     auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong);
-    auto prediction = machine.getClassifier()->getNN()(neuralInput);
+    auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
 
     int chosenTransition = -1;
+    float bestScore = std::numeric_limits<float>::min();
 
     try
     {
       for (unsigned int i = 0; i < prediction.size(0); i++)
-        if ((chosenTransition == -1 or prediction[i].item<float>() > prediction[chosenTransition].item<float>()) and machine.getTransitionSet().getTransition(i)->appliable(config))
+      {
+        float score = prediction[i].item<float>();
+        if (score > bestScore and machine.getTransitionSet().getTransition(i)->appliable(config))
           chosenTransition = i;
+      }
     } catch(std::exception & e) {util::myThrow(e.what());}
 
     if (chosenTransition == -1)