diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index f08bd0370905c626b75bc04e24200de84c116bd5..5c81d8466ba3fba514ef82437e8ca6cab1865a28 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -37,6 +37,21 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool int chosenTransition = -1; float bestScore = std::numeric_limits<float>::min(); + if (debug) + { + auto softmaxed = torch::softmax(prediction,-1); + std::vector<std::pair<float,std::string>> toPrint; + for (unsigned int i = 0; i < softmaxed.size(0); i++) + { + float score = softmaxed[i].item<float>(); + std::string nicePrint = fmt::format("{} {:7.2f} {}", machine.getTransitionSet().getTransition(i)->appliable(config) ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName()); + toPrint.emplace_back(std::make_pair(score,nicePrint)); + } + std::sort(toPrint.rbegin(), toPrint.rend()); + for (unsigned int i = 0; i < 5 and i < toPrint.size(); i++) + fmt::print(stderr, "{}\n", toPrint[i].second); + } + try { for (unsigned int i = 0; i < prediction.size(0); i++) @@ -47,10 +62,6 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool chosenTransition = i; bestScore = score; } - if (debug) - { - fmt::print(stderr, "{} {:7.2f} {}\n", machine.getTransitionSet().getTransition(i)->appliable(config) ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName()); - } } } catch(std::exception & e) {util::myThrow(e.what());}