#include "Decoder.hpp"

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <memory>
#include <regex>
#include <string>

#include "SubConfig.hpp"

namespace
{

/// Deleter for std::FILE handles obtained from std::fopen.
struct FileCloser
{
  void operator()(std::FILE * file) const { std::fclose(file); }
};

/// Deleter for std::FILE handles obtained from popen (pclose also waits for the child).
struct PipeCloser
{
  void operator()(std::FILE * pipe) const { pclose(pipe); }
};

/// Wraps @p arg in single quotes for /bin/sh, escaping embedded single quotes,
/// so that paths containing spaces or shell metacharacters reach the command verbatim.
std::string shellQuote(const std::string & arg)
{
  std::string quoted = "'";
  for (char c : arg)
  {
    if (c == '\'')
      quoted += "'\\''";
    else
      quoted += c;
  }
  quoted += '\'';
  return quoted;
}

} // namespace

/// Creates a decoder driven by @p machine (strategy, dictionaries, classifier, transition set).
Decoder::Decoder(ReadingMachine & machine) : machine(machine)
{
}

/// Greedily decodes @p config with the machine's classifier.
///
/// Starting from the strategy's initial state, each step extracts a feature
/// context, scores every transition with the neural network, applies the
/// highest-scoring transition for which the transition set returns a non-null
/// pointer, records its name in the history and moves according to the
/// strategy, until the strategy returns Strategy::endMovement.
///
/// @param config   Configuration annotated in place.
/// @param beamSize Currently unused: decoding is greedy (one transition per step).
/// @throws whatever util::myThrow throws if no transition is selected, the word
///         index cannot be moved, or any std::exception escapes a step.
void Decoder::decode(BaseConfig & config, [[maybe_unused]] std::size_t beamSize)
{
  try
  {
    config.setState(machine.getStrategy().getInitialState());

    fmt::print(stderr, "\r{:80}\rDecoding dev...", " ");

    while (true)
    {
      // The dictionary state is saved before extractContext and restored
      // afterwards (presumably so context extraction leaves it unchanged).
      auto dictState = machine.getDict(config.getState()).getState();
      auto context = config.extractContext(5,5,machine.getDict(config.getState()));
      machine.getDict(config.getState()).setState(dictState);

      auto neuralInput = torch::from_blob(context.data(), {static_cast<int64_t>(context.size())}, at::kLong);
      auto prediction = machine.getClassifier()->getNN()(neuralInput);

      // Argmax over candidate transitions; each score is extracted only once.
      // NOTE(review): candidates are filtered only on getTransition(i) being
      // non-null, not on whether the transition can be applied to config — confirm.
      const int nbTransitions = static_cast<int>(prediction.size(0));
      int chosenTransition = -1;
      float bestScore = 0.0f;
      for (int i = 0; i < nbTransitions; ++i)
      {
        if (!machine.getTransitionSet().getTransition(i))
          continue;

        const float score = prediction[i].item<float>();
        if (chosenTransition == -1 or score > bestScore)
        {
          chosenTransition = i;
          bestScore = score;
        }
      }

      if (chosenTransition == -1)
        util::myThrow("No transition appliable !");

      auto * transition = machine.getTransitionSet().getTransition(chosenTransition);

      transition->apply(config);
      config.addToHistory(transition->getName());

      auto movement = machine.getStrategy().getMovement(config, transition->getName());
      if (movement == Strategy::endMovement)
        break;

      config.setState(movement.first);
      if (!config.moveWordIndex(movement.second))
        util::myThrow("Cannot move word index !");
    }
  }
  catch (const std::exception & e)
  {
    util::myThrow(e.what());
  }
}

/// Returns score number @p scoreIndex of @p metric, as parsed by the last evaluate() call.
/// @throws whatever util::myThrow throws if @p metric was not parsed, or if
///         @p scoreIndex is out of range for that metric's scores.
float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex)
{
  auto found = evaluation.find(metric);

  if (found == evaluation.end())
    util::myThrow(fmt::format("Cannot find metric '{}' {}\n", metric, evaluation.empty() ? "(call Decoder::evaluate() first)" : ""));

  if (scoreIndex >= found->second.size())
    util::myThrow(fmt::format("Invalid score index {} for metric '{}'", scoreIndex, metric));

  return found->second[scoreIndex];
}

/// Precision of @p metric (score column 0); see getMetricScore().
float Decoder::getPrecision(const std::string & metric)
{
  return getMetricScore(metric, 0);
}

/// Recall of @p metric (score column 1); see getMetricScore().
float Decoder::getRecall(const std::string & metric)
{
  return getMetricScore(metric, 1);
}

/// F1 score of @p metric (score column 2); see getMetricScore().
float Decoder::getF1Score(const std::string & metric)
{
  return getMetricScore(metric, 2);
}

/// Aligned accuracy of @p metric (score column 3); see getMetricScore().
float Decoder::getAlignedAcc(const std::string & metric)
{
  return getMetricScore(metric, 3);
}

/// Scores the annotations of @p config against @p goldTSV with the
/// `conll18_ud_eval.py` script and stores the parsed scores in `evaluation`.
///
/// Previous scores are cleared first. The predictions are written to
/// `<modelPath>/predicted_dev.tsv`, then `../scripts/conll18_ud_eval.py
/// <gold> <predicted> -v` is run; lines containing "Metric" are skipped and
/// every pipe-separated line `name | s1 | s2 | s3 | s4` is parsed, with
/// empty cells stored as 0.
/// NOTE(review): the script path is relative to the current working directory.
///
/// @throws whatever util::myThrow throws if the predictions cannot be written,
///         the script cannot be started, or a score cell is not a number.
void Decoder::evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV)
{
  evaluation.clear();

  auto predictedTSV = (modelPath/"predicted_dev.tsv").string();
  {
    std::unique_ptr<std::FILE, FileCloser> predictedTSVFile(std::fopen(predictedTSV.c_str(), "w"));
    if (!predictedTSVFile)
      util::myThrow(fmt::format("Cannot open '{}' for writing", predictedTSV));

    config.print(predictedTSVFile.get());

    // Close explicitly so that buffered write errors are reported.
    if (std::fclose(predictedTSVFile.release()) != 0)
      util::myThrow(fmt::format("Error while writing '{}'", predictedTSV));
  }

  auto command = fmt::format("{} {} {} -v", "../scripts/conll18_ud_eval.py", shellQuote(goldTSV), shellQuote(predictedTSV));
  // RAII: the pipe is closed (and the child reaped) even if parsing throws.
  std::unique_ptr<std::FILE, PipeCloser> evalFromUD(popen(command.c_str(), "r"));
  if (!evalFromUD)
    util::myThrow(fmt::format("Cannot run '{}'", command));

  const std::regex headerRegex("(.*)Metric(.*)");
  const std::regex scoreLineRegex("(.*)\\|(.*)\\|(.*)\\|(.*)\\|(.*)");

  char buffer[1024];
  while (std::fgets(buffer, sizeof(buffer), evalFromUD.get()))
  {
    const std::size_t length = std::strlen(buffer);
    if (length > 0 and buffer[length-1] == '\n')
      buffer[length-1] = '\0';

    if (util::doIfNameMatch(headerRegex, buffer, [](const auto &){}))
      continue;

    util::doIfNameMatch(scoreLineRegex, buffer, [this, &buffer](const auto & sm)
    {
      auto metric = util::strip(sm[1]);
      // One slot per score column; the container type of `evaluation` is
      // declared in Decoder.hpp (presumably 4 scores, matching the getters).
      auto & scores = evaluation[metric];
      for (unsigned int i = 0; i < scores.size(); i++)
      {
        auto value = util::strip(sm[i+2]);
        if (value.empty())
        {
          scores[i] = 0.0;
          continue;
        }

        try
        {
          scores[i] = std::stof(value);
        }
        catch (const std::exception &)
        {
          util::myThrow(fmt::format("score '{}' is not a number in line '{}'", value, buffer));
        }
      }
    });
  }

  // Close explicitly (instead of via PipeCloser) to obtain the exit status.
  const int status = pclose(evalFromUD.release());
  if (status != 0)
    fmt::print(stderr, "Warning: evaluation command '{}' exited with status {}\n", command, status);
}