diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp index ee2fd0df190e1a38c669caaad838db0ab54c90a7..067587979e64eb219fcd1955a0e8d76341de1f03 100644 --- a/decoder/src/Beam.cpp +++ b/decoder/src/Beam.cpp @@ -49,7 +49,8 @@ void Beam::update(ReadingMachine & machine, bool debug) auto context = classifier.getNN()->extractContext(elements[index].config).back(); auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); - auto prediction = torch::softmax(classifier.getNN()(neuralInput).squeeze(), 0); + + auto prediction = torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0); std::vector<std::pair<float, int>> scoresOfTransitions; for (unsigned int i = 0; i < prediction.size(0); i++)