diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index aa0b0c2b4acc52510e6f13af019de059774b8373..291da6b5de7de8c8b0ff2b0b17a0007ee491a289 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -24,15 +24,19 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug) machine.getDict(config.getState()).setState(dictState); auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong); - auto prediction = machine.getClassifier()->getNN()(neuralInput); + auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); int chosenTransition = -1; + float bestScore = std::numeric_limits<float>::min(); try { for (unsigned int i = 0; i < prediction.size(0); i++) - if ((chosenTransition == -1 or prediction[i].item<float>() > prediction[chosenTransition].item<float>()) and machine.getTransitionSet().getTransition(i)->appliable(config)) + { + float score = prediction[i].item<float>(); + if (score > bestScore and machine.getTransitionSet().getTransition(i)->appliable(config)) chosenTransition = i; + } } catch(std::exception & e) {util::myThrow(e.what());} if (chosenTransition == -1)