From e35487c97db64e9cb28d94f3c1fbc6d589806692 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Thu, 9 Jul 2020 00:03:55 +0200 Subject: [PATCH] Corrected bug --- decoder/src/Beam.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp index ee2fd0d..0675879 100644 --- a/decoder/src/Beam.cpp +++ b/decoder/src/Beam.cpp @@ -49,7 +49,8 @@ void Beam::update(ReadingMachine & machine, bool debug) auto context = classifier.getNN()->extractContext(elements[index].config).back(); auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); - auto prediction = torch::softmax(classifier.getNN()(neuralInput).squeeze(), 0); + + auto prediction = torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0); std::vector<std::pair<float, int>> scoresOfTransitions; for (unsigned int i = 0; i < prediction.size(0); i++) -- GitLab