From bbe7862dd74e07c8acff55fd2a17a1afec9c6066 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 25 Feb 2020 11:31:56 +0100 Subject: [PATCH] squeezing the prediction of the Decoder because LSTM predict 2D tensor --- decoder/src/Decoder.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index aa0b0c2..291da6b 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -24,15 +24,19 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug) machine.getDict(config.getState()).setState(dictState); auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong); - auto prediction = machine.getClassifier()->getNN()(neuralInput); + auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); int chosenTransition = -1; + float bestScore = std::numeric_limits<float>::min(); try { for (unsigned int i = 0; i < prediction.size(0); i++) - if ((chosenTransition == -1 or prediction[i].item<float>() > prediction[chosenTransition].item<float>()) and machine.getTransitionSet().getTransition(i)->appliable(config)) + { + float score = prediction[i].item<float>(); + if (score > bestScore and machine.getTransitionSet().getTransition(i)->appliable(config)) chosenTransition = i; + } } catch(std::exception & e) {util::myThrow(e.what());} if (chosenTransition == -1) -- GitLab