#include "Decoder.hpp"

#include <cstdio>
#include <cstring>
#include <memory>

#include "SubConfig.hpp"

// Bind the decoder to the ReadingMachine that supplies the classifier, the
// strategy, the dictionaries and the transition set.
// NOTE(review): the member's declaration lives in Decoder.hpp; if it is a
// reference, `machine` must outlive this Decoder — confirm.
Decoder::Decoder(ReadingMachine & machine) : machine(machine)
{
}

// Greedily decode `config` in place.
//
// At each step the network scores every transition of the transition set, and
// the highest-scoring transition that is appliable to the current
// configuration is applied. The strategy then gives the next state and the
// word-index movement. Decoding stops when the strategy returns
// Strategy::endMovement. Afterwards an EOS transition may be forced, and
// missing columns are filled in.
//
// Parameters:
//   config           configuration to annotate (modified in place)
//   beamSize         not used by this implementation (decoding is greedy)
//   debug            print the configuration and each
//                    (transition, new state, movement) triple to stderr
//   printAdvancement print a decoding-speed estimate to stderr every
//                    `printInterval` transitions
// Errors are reported through util::myThrow (e.g. "No transition appliable !",
// "Cannot move word index !").
void Decoder::decode(BaseConfig & config, [[maybe_unused]] std::size_t beamSize, bool debug, bool printAdvancement)
{
  // Inference only: no autograd graph, network switched out of training mode.
  torch::AutoGradMode useGrad(false);
  machine.getClassifier()->getNN()->train(false);
  config.addPredicted(machine.getPredicted());

  constexpr int printInterval = 50;
  int nbExamplesProcessed = 0;
  auto pastTime = std::chrono::high_resolution_clock::now();

  try
  {
    config.setState(machine.getStrategy().getInitialState());

    while (true)
    {
      if (debug)
        config.printForDebug(stderr);

      // Save the dictionary state around context extraction and restore it
      // afterwards. Presumably extractContext can mutate the dictionary, and
      // decoding must not — confirm.
      auto dictState = machine.getDict(config.getState()).getState();
      auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState()));
      machine.getDict(config.getState()).setState(dictState);

      // clone() gives the tensor its own storage, since from_blob only wraps `context`.
      auto neuralInput = torch::from_blob(context.data(), {static_cast<long>(context.size())}, at::kLong).clone().to(NeuralNetworkImpl::device);
      auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();

      // Argmax over the scores, restricted to transitions appliable to config.
      // appliable() is only called for candidates that would improve the best score.
      int chosenTransition = -1;
      float bestScore = std::numeric_limits<float>::lowest();
      try
      {
        for (unsigned int i = 0; i < prediction.size(0); i++)
        {
          float score = prediction[i].item<float>();
          if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config))
          {
            chosenTransition = i;
            bestScore = score;
          }
        }
      } catch (const std::exception & e) {util::myThrow(e.what());}

      if (chosenTransition == -1)
      {
        config.printForDebug(stderr);
        util::myThrow("No transition appliable !");
      }

      auto * transition = machine.getTransitionSet().getTransition(chosenTransition);
      transition->apply(config);
      config.addToHistory(transition->getName());

      if (printAdvancement and ++nbExamplesProcessed >= printInterval)
      {
        auto actualTime = std::chrono::high_resolution_clock::now();
        double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
        // With zero elapsed time the rate would be infinite, and converting
        // that to int is UB. In that case keep accumulating until time has passed.
        if (seconds > 0.0)
        {
          pastTime = actualTime;
          fmt::print(stderr, "\r{:80}\rdecoding... speed={:<6}ex/s\r", "", static_cast<int>(nbExamplesProcessed/seconds));
          nbExamplesProcessed = 0;
        }
      }

      auto movement = machine.getStrategy().getMovement(config, transition->getName());
      if (debug)
        fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
      if (movement == Strategy::endMovement)
        break;

      config.setState(movement.first);
      if (!config.moveWordIndex(movement.second))
        util::myThrow("Cannot move word index !");
    }
  } catch (const std::exception & e) {util::myThrow(e.what());}

  // Force EOS when needed. If an "EOS" transition exists and the last
  // non-empty EOS hypothesis at the current word index is not
  // Config::EOSSymbol1, push the word index on the stack and apply EOS.
  // This presumably guarantees the final sentence is closed.
  if (machine.getTransitionSet().getTransition("EOS") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1)
  {
    Action shift = Action::pushWordIndexOnStack();
    shift.apply(config, shift);
    machine.getTransitionSet().getTransition("EOS")->apply(config);
    if (debug)
      fmt::print(stderr, "Forcing EOS transition\n");
  }

  // Fill holes in important columns like "ID" and "HEAD" to be compatible with eval script
  try {config.addMissingColumns();}
  catch (const std::exception & e) {util::myThrow(e.what());}
}

// Return score number `scoreIndex` of `metric`, as parsed by evaluate().
// Index layout, as used by the getters below:
// 0 = precision, 1 = recall, 2 = F1, 3 = aligned accuracy.
// Calls util::myThrow if the metric is unknown; the message hints that
// evaluate() has not run when no scores are stored at all.
float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
{
  auto found = evaluation.find(metric);

  if (found == evaluation.end())
    util::myThrow(fmt::format("Cannot find metric '{}' {}\n", metric, evaluation.empty() ? "(call Decoder::evaluate() first)" : ""));

  return found->second[scoreIndex];
}

// Precision of `metric` (score index 0).
float Decoder::getPrecision(const std::string & metric) const
{
  return getMetricScore(metric, 0);
}

// Recall of `metric` (score index 1).
float Decoder::getRecall(const std::string & metric) const
{
  return getMetricScore(metric, 1);
}

// F1 score of `metric` (score index 2).
float Decoder::getF1Score(const std::string & metric) const
{
  return getMetricScore(metric, 2);
}

// Aligned accuracy of `metric` (score index 3).
float Decoder::getAlignedAcc(const std::string & metric) const
{
  return getMetricScore(metric, 3);
}

// (F1, metric name) for each column in colNames except the ID column.
std::vector<std::pair<float,std::string>> Decoder::getF1Scores(const std::set<std::string> & colNames) const
{
  return getScores(colNames, &Decoder::getF1Score);
}

// (aligned accuracy, metric name) for each column in colNames except the ID column.
std::vector<std::pair<float,std::string>> Decoder::getAlignedAccs(const std::set<std::string> & colNames) const
{
  return getScores(colNames, &Decoder::getAlignedAcc);
}

// (recall, metric name) for each column in colNames except the ID column.
std::vector<std::pair<float,std::string>> Decoder::getRecalls(const std::set<std::string> & colNames) const
{
  return getScores(colNames, &Decoder::getRecall);
}

// (precision, metric name) for each column in colNames except the ID column.
std::vector<std::pair<float,std::string>> Decoder::getPrecisions(const std::set<std::string> & colNames) const
{
  return getScores(colNames, &Decoder::getPrecision);
}

// For every column name except Config::idColName, map it to its evaluation
// metric (getMetricOfColName) and pair that metric's score, read through
// `metric2score`, with the metric name. Results follow colNames' iteration order.
std::vector<std::pair<float,std::string>> Decoder::getScores(const std::set<std::string> & colNames, float (Decoder::* metric2score)(const std::string &) const) const
{
  std::vector<std::pair<float, std::string>> scores;

  for (auto & colName : colNames)
  {
    if (colName == Config::idColName)
      continue;

    auto metric = getMetricOfColName(colName);
    scores.emplace_back((this->*metric2score)(metric), metric);
  }

  return scores;
}

// Translate a column name into the name of the metric row printed by the
// evaluation script. Columns without a special mapping use their own name.
std::string Decoder::getMetricOfColName(const std::string & colName) const
{
  if (colName == "HEAD")
    return "UAS";
  if (colName == "DEPREL")
    return "LAS";
  if (colName == "EOS")
    return "Sentences";
  if (colName == "FEATS")
    return "UFeats";
  if (colName == "FORM")
    return "Words";

  return colName;
}

// Write `config` to <modelPath>/predicted_dev.tsv and run
// ../scripts/conll18_ud_eval.py (path relative to the working directory)
// against `goldTSV` in verbose mode. Scores parsed from the script's table
// replace the contents of `evaluation`.
// Errors are reported through util::myThrow: the prediction file cannot be
// opened, the command cannot be spawned, or a score cell is not a number.
// NOTE(review): the paths are pasted into a shell command unquoted, so paths
// containing spaces or shell metacharacters will break.
void Decoder::evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV)
{
  evaluation.clear();

  auto predictedTSV = (modelPath/"predicted_dev.tsv").string();
  std::FILE * predictedTSVFile = std::fopen(predictedTSV.c_str(), "w");
  if (!predictedTSVFile)
    util::myThrow(fmt::format("Cannot open file '{}' for writing", predictedTSV));
  config.print(predictedTSVFile);
  std::fclose(predictedTSVFile);

  auto command = fmt::format("{} {} {} -v", "../scripts/conll18_ud_eval.py", goldTSV, predictedTSV);
  // The pipe is owned by a unique_ptr, so pclose also runs when parsing throws.
  auto closePipe = [](std::FILE * pipe) { pclose(pipe); };
  std::unique_ptr<std::FILE, decltype(closePipe)> evalFromUD(popen(command.c_str(), "r"), closePipe);
  if (!evalFromUD)
    util::myThrow(fmt::format("Cannot run command '{}'", command));

  // Compiled once instead of once per line: building a std::regex is expensive.
  std::regex headerLine("(.*)Metric(.*)");
  std::regex tableLine("(.*)\\|(.*)\\|(.*)\\|(.*)\\|(.*)");

  char buffer[1024];
  while (std::fgets(buffer, sizeof(buffer), evalFromUD.get()))
  {
    // Strip the trailing newline. The length guard avoids indexing
    // buffer[-1] when the line read is empty.
    auto length = std::strlen(buffer);
    if (length > 0 and buffer[length-1] == '\n')
      buffer[length-1] = '\0';

    // Skip the table header line.
    if (util::doIfNameMatch(headerLine, buffer, [](auto){}))
      continue;

    // Table row: metric name, then one cell per score slot. An empty cell is
    // stored as 0. Slot order presumably matches getMetricScore's indices
    // (precision, recall, F1, aligned accuracy) — confirm against the script.
    util::doIfNameMatch(tableLine, buffer, [this, &buffer](auto sm)
    {
      auto metric = util::strip(sm[1]);
      auto & scores = this->evaluation[metric];
      for (unsigned int i = 0; i < scores.size(); i++)
      {
        auto value = util::strip(sm[i+2]);
        if (value.empty())
        {
          scores[i] = 0.0;
          continue;
        }
        try {scores[i] = std::stof(value);}
        catch (std::exception &)
        {
          util::myThrow(fmt::format("score '{}' is not a number in line '{}'", value, buffer));
        }
      }
    });
  }
}