Skip to content
Snippets Groups Projects
Commit bbe7862d authored by Franck Dary's avatar Franck Dary
Browse files

squeezing the prediction of the Decoder because LSTM predict 2D tensor

parent e8cb9812
No related branches found
No related tags found
No related merge requests found
...@@ -24,15 +24,19 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug) ...@@ -24,15 +24,19 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
machine.getDict(config.getState()).setState(dictState); machine.getDict(config.getState()).setState(dictState);
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong); auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong);
auto prediction = machine.getClassifier()->getNN()(neuralInput); auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
int chosenTransition = -1; int chosenTransition = -1;
float bestScore = std::numeric_limits<float>::min();
try try
{ {
for (unsigned int i = 0; i < prediction.size(0); i++) for (unsigned int i = 0; i < prediction.size(0); i++)
if ((chosenTransition == -1 or prediction[i].item<float>() > prediction[chosenTransition].item<float>()) and machine.getTransitionSet().getTransition(i)->appliable(config)) {
float score = prediction[i].item<float>();
if (score > bestScore and machine.getTransitionSet().getTransition(i)->appliable(config))
chosenTransition = i; chosenTransition = i;
}
} catch(std::exception & e) {util::myThrow(e.what());} } catch(std::exception & e) {util::myThrow(e.what());}
if (chosenTransition == -1) if (chosenTransition == -1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment