// Decoder.cpp (3.6 KiB)
// Author: Franck Dary
#include "Decoder.hpp"

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <exception>
#include <memory>
#include <regex>
#include <string>

#include "SubConfig.hpp"

/// Creates a decoder that will decode configurations with `machine`.
Decoder::Decoder(ReadingMachine & machine)
  : machine(machine)
{}

void Decoder::decode(BaseConfig & config, std::size_t beamSize)
{
  config.setState(machine.getStrategy().getInitialState());

  while (true)
  {
    auto dictState = machine.getDict(config.getState()).getState();
    auto context = config.extractContext(5,5,machine.getDict(config.getState()));
    machine.getDict(config.getState()).setState(dictState);

    //TODO : check if clone is mandatory
    auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone();
    //TODO : check if NoGradGuard does anything
    torch::NoGradGuard guard;
    auto prediction = machine.getClassifier()->getNN()(neuralInput);

    int chosenTransition = -1;

    for (unsigned int i = 0; i < prediction.size(0); i++)
      if ((chosenTransition == -1 or prediction[i].item<float>() > prediction[chosenTransition].item<float>()) and machine.getTransitionSet().getTransition(i))
        chosenTransition = i;

    if (chosenTransition == -1)
      util::myThrow("No transition appliable !");

    auto * transition = machine.getTransitionSet().getTransition(chosenTransition);

    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
      util::myThrow("Cannot move word index !");
  }
}

/// Returns one score of `metric` from the last Decoder::evaluate() run.
///
/// @param metric     metric name, used verbatim as the lookup key into the
///                   parsed evaluation table.
/// @param scoreIndex score column to read. The getters use 0 = precision,
///                   1 = recall, 2 = F1 score, 3 = aligned accuracy.
/// @return the stored score.
/// @throws via util::myThrow if the metric is unknown (e.g. evaluate() was
///         never called) or if scoreIndex is out of range.
float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex)
{
  auto found = evaluation.find(metric);

  if (found == evaluation.end())
    util::myThrow(fmt::format("Cannot find metric '{}' {}\n", metric, evaluation.empty() ? "(call Decoder::evaluate() first)" : ""));

  // Indexing past the stored scores would be undefined behaviour.
  if (scoreIndex >= found->second.size())
    util::myThrow(fmt::format("Score index {} out of range for metric '{}' ({} scores available)\n", scoreIndex, metric, found->second.size()));

  return found->second[scoreIndex];
}

float Decoder::getPrecision(const std::string & metric)
{
  return getMetricScore(metric, 0);
}

float Decoder::getRecall(const std::string & metric)
{
  return getMetricScore(metric, 1);
}

float Decoder::getF1Score(const std::string & metric)
{
  return getMetricScore(metric, 2);
}

float Decoder::getAlignedAcc(const std::string & metric)
{
  return getMetricScore(metric, 3);
}

void Decoder::evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV)
{
  evaluation.clear();
  auto predictedTSV = (modelPath/"predicted_dev.tsv").string();
  std::FILE * predictedTSVFile = std::fopen(predictedTSV.c_str(), "w");
  config.print(predictedTSVFile);
  std::fclose(predictedTSVFile);

  std::FILE * evalFromUD = popen(fmt::format("{} {} {} -v", "../scripts/conll18_ud_eval.py", goldTSV, predictedTSV).c_str(), "r");

  char buffer[1024];
  while (!std::feof(evalFromUD))
  {
    if (buffer != std::fgets(buffer, 1024, evalFromUD))
      break;
Franck Dary's avatar
Franck Dary committed
    if (buffer[std::strlen(buffer)-1] == '\n')
      buffer[std::strlen(buffer)-1] = '\0';
    if (util::doIfNameMatch(std::regex("(.*)Metric(.*)"), buffer, [this, buffer](auto sm){}))
      continue;

    if (util::doIfNameMatch(std::regex("(.*)\\|(.*)\\|(.*)\\|(.*)\\|(.*)"), buffer, [this, buffer](auto sm)
Franck Dary's avatar
Franck Dary committed
    {
      for (unsigned int i = 0; i < this->evaluation[sm[1]].size(); i++)
Franck Dary's avatar
Franck Dary committed
      {
        if (std::string(sm[i+2]).empty())
        {
          this->evaluation[sm[1]][i] = 0.0;
          continue;
        }
        try {this->evaluation[sm[1]][i] = std::stof(sm[i+2]);}
        catch (std::exception &) 
        {
          util::myThrow(fmt::format("score '{}' is not a number in line '{}'", std::string(sm[i+2]), buffer));
        }
      }
Franck Dary's avatar
Franck Dary committed
    })){}
  }

  pclose(evalFromUD);
}