From 92e75482f2f8e9239bc7e0b5c2f7286ddbb83c76 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sat, 11 Apr 2020 11:37:32 +0200 Subject: [PATCH] improved debug print when decoding --- decoder/src/Decoder.cpp | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index f08bd03..5c81d84 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -37,6 +37,21 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool int chosenTransition = -1; float bestScore = std::numeric_limits<float>::min(); + if (debug) + { + auto softmaxed = torch::softmax(prediction,-1); + std::vector<std::pair<float,std::string>> toPrint; + for (unsigned int i = 0; i < softmaxed.size(0); i++) + { + float score = softmaxed[i].item<float>(); + std::string nicePrint = fmt::format("{} {:7.2f} {}", machine.getTransitionSet().getTransition(i)->appliable(config) ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName()); + toPrint.emplace_back(std::make_pair(score,nicePrint)); + } + std::sort(toPrint.rbegin(), toPrint.rend()); + for (unsigned int i = 0; i < 5 and i < toPrint.size(); i++) + fmt::print(stderr, "{}\n", toPrint[i].second); + } + try { for (unsigned int i = 0; i < prediction.size(0); i++) @@ -47,10 +62,6 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool chosenTransition = i; bestScore = score; } - if (debug) - { - fmt::print(stderr, "{} {:7.2f} {}\n", machine.getTransitionSet().getTransition(i)->appliable(config) ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName()); - } } } catch(std::exception & e) {util::myThrow(e.what());} -- GitLab