Skip to content
Snippets Groups Projects
Decoder.cpp 5.32 KiB
Newer Older
Franck Dary's avatar
Franck Dary committed
#include "Decoder.hpp"
#include "SubConfig.hpp"
Franck Dary's avatar
Franck Dary committed

// Builds a Decoder around `machine`: the argument initialises the member of the
// same name, and nothing else happens at construction time.
Decoder::Decoder(ReadingMachine & machine)
  : machine(machine)
{
}

// Decodes `baseConfig` with the machine held by this Decoder.
//
// Steps visible in this block:
//   1. gradient tracking is switched off through torch::AutoGradMode(false), training mode
//      is turned off on the machine and its dictionaries are put in the Closed state;
//   2. a Beam is built from `beamSize`, `beamThreshold`, `baseConfig` and the machine, and
//      Beam::update is called on it;
//   3. the configuration of beam element 0 is copied back into `baseConfig` and the
//      classifier state is set from it;
//   4. if an "EOS b.0" transition exists and the last non-empty hypothesis in the EOS column
//      at the current word index is not Config::EOSSymbol1, that transition is applied;
//   5. BaseConfig::addMissingColumns is called on the result.
//
// Parameters:
//   baseConfig       - configuration to decode; overwritten with the decoded result.
//   beamSize         - forwarded to the Beam constructor.
//   beamThreshold    - forwarded to the Beam constructor.
//   debug            - forwarded to Beam::update; also guards the debug output of step 4.
//   printAdvancement - when true, a speed line is written to stderr each time the update
//                      counter reaches printInterval (50).
//
// Any std::exception caught in steps 2 and 5 is forwarded to util::myThrow with its message.
//
// NOTE(review): this block appears damaged by extraction. The lines reading
// "Franck Dary ... avatar" and "Franck Dary committed" are not C++, and several structural
// lines (a loop header around the update call, and the braces of the nested if bodies)
// seem to be missing. Restore them from the original file before compiling.
void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement)
Franck Dary's avatar
Franck Dary committed
{
  constexpr int printInterval = 50;

  torch::AutoGradMode useGrad(false);
Franck Dary's avatar
Franck Dary committed
  // Turn training mode off and put the dictionaries in the Closed state.
  machine.trainMode(false);
  machine.setDictsState(Dict::State::Closed);
Franck Dary's avatar
Franck Dary committed
  int nbExamplesProcessed = 0;
  auto pastTime = std::chrono::high_resolution_clock::now();

  Beam beam(beamSize, beamThreshold, baseConfig, machine);
Franck Dary's avatar
Franck Dary committed
  try
  {
      // NOTE(review): the extra indentation suggests this call sat inside a loop whose
      // header is not visible here - confirm the loop condition against the original file.
      beam.update(machine, debug);
      // Progress report: count updates and, every printInterval of them, print the
      // number of updates per second measured since the previous report.
      if (printAdvancement)
        if (++nbExamplesProcessed >= printInterval)
          auto actualTime = std::chrono::high_resolution_clock::now();
          double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
          pastTime = actualTime;
          fmt::print(stderr, "\r{:80}\rdecoding... speed={:<6}ex/s\r", "", (int)(nbExamplesProcessed/seconds));
          nbExamplesProcessed = 0;
Franck Dary's avatar
Franck Dary committed
  } catch(std::exception & e) {util::myThrow(e.what());}
  baseConfig = beam[0].config;
  machine.getClassifier()->setState(baseConfig.getState());

  if (machine.getTransitionSet().getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1)
    machine.getTransitionSet().getTransition("EOS b.0")->apply(baseConfig);
    if (debug)
      fmt::print(stderr, "Forcing EOS transition\n");
      baseConfig.printForDebug(stderr);
  // Fill holes in important columns like "ID" and "HEAD" to be compatible with eval script
  try {baseConfig.addMissingColumns();}
  catch (std::exception & e) {util::myThrow(e.what());}
Franck Dary's avatar
Franck Dary committed
}

float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
Franck Dary's avatar
Franck Dary committed
{
  auto found = evaluation.find(metric);

  if (found == evaluation.end())
    util::myThrow(fmt::format("Cannot find metric '{}' {}\n", metric, evaluation.empty() ? "(call Decoder::evaluate() first)" : ""));

  return found->second[scoreIndex];
}

float Decoder::getPrecision(const std::string & metric) const
Franck Dary's avatar
Franck Dary committed
{
  return getMetricScore(metric, 0);
}

float Decoder::getRecall(const std::string & metric) const
Franck Dary's avatar
Franck Dary committed
{
  return getMetricScore(metric, 1);
}

float Decoder::getF1Score(const std::string & metric) const
Franck Dary's avatar
Franck Dary committed
{
  return getMetricScore(metric, 2);
}

float Decoder::getAlignedAcc(const std::string & metric) const
Franck Dary's avatar
Franck Dary committed
{
  return getMetricScore(metric, 3);
}

// F1 variant of getScores(): delegates to it with getF1Score as the score reader
// and returns whatever it produces for `colNames`.
std::vector<std::pair<float,std::string>> Decoder::getF1Scores(const std::set<std::string> & colNames) const
{
  const auto scoreReader = &Decoder::getF1Score;

  return getScores(colNames, scoreReader);
}

// Aligned-accuracy variant of getScores(): delegates to it with getAlignedAcc as the
// score reader and returns whatever it produces for `colNames`.
std::vector<std::pair<float,std::string>> Decoder::getAlignedAccs(const std::set<std::string> & colNames) const
{
  const auto scoreReader = &Decoder::getAlignedAcc;

  return getScores(colNames, scoreReader);
}

// Recall variant of getScores(): delegates to it with getRecall as the score reader
// and returns whatever it produces for `colNames`.
std::vector<std::pair<float,std::string>> Decoder::getRecalls(const std::set<std::string> & colNames) const
{
  const auto scoreReader = &Decoder::getRecall;

  return getScores(colNames, scoreReader);
}

// Precision variant of getScores(): delegates to it with getPrecision as the score
// reader and returns whatever it produces for `colNames`.
std::vector<std::pair<float,std::string>> Decoder::getPrecisions(const std::set<std::string> & colNames) const
{
  const auto scoreReader = &Decoder::getPrecision;

  return getScores(colNames, scoreReader);
}

// Collects one (score, metric name) pair per requested column.
//
// Every name in `colNames` other than Config::idColName is turned into a metric name with
// getMetricOfColName(), and `metric2score` is invoked on this Decoder with that metric name.
//
// Parameters:
//   colNames     - column names to report on.
//   metric2score - const member function of Decoder used to read the score of a metric.
//
// Returns the (score, metric name) pairs in the iteration order of `colNames`.
std::vector<std::pair<float,std::string>> Decoder::getScores(const std::set<std::string> & colNames, float (Decoder::* metric2score)(const std::string &) const) const
{
  std::vector<std::pair<float, std::string>> scores;
  scores.reserve(colNames.size());

  for (const auto & colName : colNames)
  {
    if (colName == Config::idColName)
      continue;

    // Resolve the metric name once and reuse it for both the lookup and the label.
    const std::string metricName = getMetricOfColName(colName);
    scores.emplace_back((this->*metric2score)(metricName), metricName);
  }

  return scores;
}

// Maps a column name to the name of the evaluation metric that scores it.
// Mappings visible in this block: "HEAD" -> "UAS", "DEPREL" -> "LAS",
// "EOS" -> "Sentences" and "FEATS" -> "UFeats".
//
// NOTE(review): the definition stops after the "FEATS" case, with no fallback return and
// no closing brace before the next definition. The tail of this function appears to have
// been lost in extraction and has to be restored from the original file.
std::string Decoder::getMetricOfColName(const std::string & colName) const
{
  if (colName == "HEAD")
    return "UAS";
  if (colName == "DEPREL")
    return "LAS";
  if (colName == "EOS")
    return "Sentences";
  if (colName == "FEATS")
    return "UFeats";
Franck Dary's avatar
Franck Dary committed
void Decoder::evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV)
{
  evaluation.clear();
  auto predictedTSV = (modelPath/"predicted_dev.tsv").string();
  std::FILE * predictedTSVFile = std::fopen(predictedTSV.c_str(), "w");
  config.print(predictedTSVFile);
  std::fclose(predictedTSVFile);

Franck Dary's avatar
Franck Dary committed
  std::FILE * evalFromUD = popen(fmt::format("{} {} {}", "../scripts/conll18_ud_eval.py", goldTSV, predictedTSV).c_str(), "r");
Franck Dary's avatar
Franck Dary committed

  char buffer[1024];
  while (!std::feof(evalFromUD))
  {
    if (buffer != std::fgets(buffer, 1024, evalFromUD))
      break;
Franck Dary's avatar
Franck Dary committed
    if (buffer[std::strlen(buffer)-1] == '\n')
      buffer[std::strlen(buffer)-1] = '\0';
    if (util::doIfNameMatch(std::regex("(.*)Metric(.*)"), buffer, [this, buffer](auto sm){}))
      continue;

    if (util::doIfNameMatch(std::regex("(.*)\\|(.*)\\|(.*)\\|(.*)\\|(.*)"), buffer, [this, buffer](auto sm)
Franck Dary's avatar
Franck Dary committed
    {
Franck Dary's avatar
Franck Dary committed
      auto metric = util::strip(sm[1]);
      for (unsigned int i = 0; i < this->evaluation[metric].size(); i++)
Franck Dary's avatar
Franck Dary committed
      {
Franck Dary's avatar
Franck Dary committed
        auto value = util::strip(sm[i+2]);
        if (value.empty())
Franck Dary's avatar
Franck Dary committed
        {
Franck Dary's avatar
Franck Dary committed
          this->evaluation[metric][i] = 0.0;
Franck Dary's avatar
Franck Dary committed
          continue;
        }
Franck Dary's avatar
Franck Dary committed
        try {this->evaluation[metric][i] = std::stof(value);}
Franck Dary's avatar
Franck Dary committed
        catch (std::exception &)
Franck Dary's avatar
Franck Dary committed
        {
Franck Dary's avatar
Franck Dary committed
          util::myThrow(fmt::format("score '{}' is not a number in line '{}'", value, buffer));
Franck Dary's avatar
Franck Dary committed
        }
      }
Franck Dary's avatar
Franck Dary committed
    })){}
  }

  pclose(evalFromUD);
}