Commit e35487c9 authored by Franck Dary's avatar Franck Dary
Browse files

Corrected bug

parent 3c41224b
......@@ -49,7 +49,8 @@ void Beam::update(ReadingMachine & machine, bool debug)
auto context = classifier.getNN()->extractContext(elements[index].config).back();
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = torch::softmax(classifier.getNN()(neuralInput).squeeze(), 0);
auto prediction = torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
std::vector<std::pair<float, int>> scoresOfTransitions;
for (unsigned int i = 0; i < prediction.size(0); i++)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment