// Decoder.cpp
    #include "Decoder.hpp"
    #include "SubConfig.hpp"
    
    /// Builds a decoder bound to `machine`; the argument is stored in the
    /// `machine` member and used by every decoding step.
    Decoder::Decoder(ReadingMachine & machine)
      : machine(machine)
    {
    }
    
    
    // --- Decoding ---
    /// Greedily decodes `config` with the reading machine.
    ///
    /// Starting from the strategy's initial state, each iteration extracts the
    /// context of the current configuration, feeds it to the classifier's
    /// network, applies the best-scored transition that is appliable, and then
    /// follows the movement given by the strategy. The loop stops when the
    /// strategy returns Strategy::endMovement. Afterwards an "EOS" transition is
    /// forced when the transition set has one and the last non-empty EOS
    /// hypothesis at the current word index differs from Config::EOSSymbol1.
    ///
    /// Errors (no appliable transition, word index that cannot be moved, any
    /// caught std::exception) are reported through util::myThrow.
    ///
    /// @param config   configuration to decode; modified in place.
    /// @param beamSize not read by this implementation.
    /// @param debug    when true, the configuration and every chosen transition
    ///                 are printed to stderr.
    void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
    {
      config.addPredicted(machine.getPredicted());

      try
      {
        config.setState(machine.getStrategy().getInitialState());

        while (true)
        {
          if (debug)
            config.printForDebug(stderr);

          // The dictionary state is saved before context extraction and restored
          // right after it. NOTE(review): presumably so that extractContext
          // cannot leave the dictionary in a different state -- confirm.
          auto dictState = machine.getDict(config.getState()).getState();
          auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState()));
          machine.getDict(config.getState()).setState(dictState);

          // from_blob wraps the buffer of `context` without copying it, so
          // `context` has to stay alive while `neuralInput` is in use.
          auto neuralInput = torch::from_blob(context.data(), {static_cast<long>(context.size())}, at::kLong);
          auto prediction = machine.getClassifier()->getNN()(neuralInput);

          int chosenTransition = -1;
          float bestScore = 0.0f; // Score of chosenTransition; only meaningful once it is != -1.

          try
          {
            const auto nbScores = static_cast<unsigned int>(prediction.size(0));
            for (unsigned int i = 0; i < nbScores; i++)
            {
              const float score = prediction[i].item<float>();

              // Only a strictly better score can replace the current choice.
              if (chosenTransition != -1 and not (score > bestScore))
                continue;

              if (machine.getTransitionSet().getTransition(i)->appliable(config))
              {
                chosenTransition = i;
                bestScore = score;
              }
            }
          } catch(const std::exception & e) {util::myThrow(e.what());}

          if (chosenTransition == -1)
            util::myThrow("No transition appliable !");

          auto * transition = machine.getTransitionSet().getTransition(chosenTransition);

          transition->apply(config);
          config.addToHistory(transition->getName());

          auto movement = machine.getStrategy().getMovement(config, transition->getName());

          if (debug)
            fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);

          if (movement == Strategy::endMovement)
            break;

          config.setState(movement.first);
          if (!config.moveWordIndex(movement.second))
            util::myThrow("Cannot move word index !");
        }
      } catch(const std::exception & e) {util::myThrow(e.what());}

      // Force EOS when needed
      auto * eosTransition = machine.getTransitionSet().getTransition("EOS");
      if (eosTransition and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1)
      {
        Action shift = Action::pushWordIndexOnStack();
        shift.apply(config, shift);
        eosTransition->apply(config);
        if (debug)
          fmt::print(stderr, "Forcing EOS transition\n");
      }
    }
    
    
    float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
    
    Franck Dary's avatar
    Franck Dary committed
    {
      auto found = evaluation.find(metric);
    
      if (found == evaluation.end())
        util::myThrow(fmt::format("Cannot find metric '{}' {}\n", metric, evaluation.empty() ? "(call Decoder::evaluate() first)" : ""));
    
      return found->second[scoreIndex];
    }
    
    
    float Decoder::getPrecision(const std::string & metric) const
    
    Franck Dary's avatar
    Franck Dary committed
    {
      return getMetricScore(metric, 0);
    }
    
    
    float Decoder::getRecall(const std::string & metric) const
    
    Franck Dary's avatar
    Franck Dary committed
    {
      return getMetricScore(metric, 1);
    }
    
    
    float Decoder::getF1Score(const std::string & metric) const
    
    Franck Dary's avatar
    Franck Dary committed
    {
      return getMetricScore(metric, 2);
    }
    
    
    float Decoder::getAlignedAcc(const std::string & metric) const
    
    Franck Dary's avatar
    Franck Dary committed
    {
      return getMetricScore(metric, 3);
    }
    
    
    /// Collects the scores of the columns in `colNames` by delegating to
    /// getScores with getF1Score as the per-metric accessor.
    std::vector<std::pair<float,std::string>> Decoder::getF1Scores(const std::set<std::string> & colNames) const
    {
      const auto scoreOf = &Decoder::getF1Score;
      return getScores(colNames, scoreOf);
    }
    
    
    /// Collects the scores of the columns in `colNames` by delegating to
    /// getScores with getAlignedAcc as the per-metric accessor.
    std::vector<std::pair<float,std::string>> Decoder::getAlignedAccs(const std::set<std::string> & colNames) const
    {
      const auto scoreOf = &Decoder::getAlignedAcc;
      return getScores(colNames, scoreOf);
    }
    
    
    /// Collects the scores of the columns in `colNames` by delegating to
    /// getScores with getRecall as the per-metric accessor.
    std::vector<std::pair<float,std::string>> Decoder::getRecalls(const std::set<std::string> & colNames) const
    {
      const auto scoreOf = &Decoder::getRecall;
      return getScores(colNames, scoreOf);
    }
    
    
    /// Collects the scores of the columns in `colNames` by delegating to
    /// getScores with getPrecision as the per-metric accessor.
    std::vector<std::pair<float,std::string>> Decoder::getPrecisions(const std::set<std::string> & colNames) const
    {
      const auto scoreOf = &Decoder::getPrecision;
      return getScores(colNames, scoreOf);
    }
    
    
    /// Builds one (score, metric name) pair per column name.
    ///
    /// @param colNames     column names to report on; each one is mapped to its
    ///                     metric name with getMetricOfColName.
    /// @param metric2score member function that returns the wanted score for a
    ///                     metric name (e.g. &Decoder::getF1Score).
    /// @return the pairs, in the iteration order of `colNames`.
    std::vector<std::pair<float,std::string>> Decoder::getScores(const std::set<std::string> & colNames, float (Decoder::* metric2score)(const std::string &) const) const
    {
      std::vector<std::pair<float, std::string>> scores;
      scores.reserve(colNames.size());

      for (const auto & colName : colNames)
      {
        // Resolve the metric name once and use it for both the lookup and the label.
        const auto metric = getMetricOfColName(colName);
        scores.emplace_back((this->*metric2score)(metric), metric);
      }

      return scores;
    }
    
    /// Maps a column name to the metric name used as key in the evaluation
    /// results: "HEAD" -> "UAS", "DEPREL" -> "LAS", "EOS" -> "Sentences".
    /// Any other column name is returned unchanged.
    std::string Decoder::getMetricOfColName(const std::string & colName) const
    {
      static const std::pair<const char *, const char *> renamedColumns[] =
      {
        {"HEAD", "UAS"},
        {"DEPREL", "LAS"},
        {"EOS", "Sentences"},
      };

      for (const auto & renamed : renamedColumns)
        if (colName == renamed.first)
          return renamed.second;

      return colName;
    }
    
    
    // --- Evaluation ---
    void Decoder::evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV)
    {
      evaluation.clear();
      auto predictedTSV = (modelPath/"predicted_dev.tsv").string();
      std::FILE * predictedTSVFile = std::fopen(predictedTSV.c_str(), "w");
      config.print(predictedTSVFile);
      std::fclose(predictedTSVFile);
    
      std::FILE * evalFromUD = popen(fmt::format("{} {} {} -v", "../scripts/conll18_ud_eval.py", goldTSV, predictedTSV).c_str(), "r");
    
      char buffer[1024];
      while (!std::feof(evalFromUD))
      {
        if (buffer != std::fgets(buffer, 1024, evalFromUD))
          break;
    
    Franck Dary's avatar
    Franck Dary committed
        if (buffer[std::strlen(buffer)-1] == '\n')
          buffer[std::strlen(buffer)-1] = '\0';
        if (util::doIfNameMatch(std::regex("(.*)Metric(.*)"), buffer, [this, buffer](auto sm){}))
          continue;
    
        if (util::doIfNameMatch(std::regex("(.*)\\|(.*)\\|(.*)\\|(.*)\\|(.*)"), buffer, [this, buffer](auto sm)
    
    Franck Dary's avatar
    Franck Dary committed
        {
    
    Franck Dary's avatar
    Franck Dary committed
          auto metric = util::strip(sm[1]);
          for (unsigned int i = 0; i < this->evaluation[metric].size(); i++)
    
    Franck Dary's avatar
    Franck Dary committed
          {
    
    Franck Dary's avatar
    Franck Dary committed
            auto value = util::strip(sm[i+2]);
            if (value.empty())
    
    Franck Dary's avatar
    Franck Dary committed
            {
    
    Franck Dary's avatar
    Franck Dary committed
              this->evaluation[metric][i] = 0.0;
    
    Franck Dary's avatar
    Franck Dary committed
              continue;
            }
    
    Franck Dary's avatar
    Franck Dary committed
            try {this->evaluation[metric][i] = std::stof(value);}
    
    Franck Dary's avatar
    Franck Dary committed
            catch (std::exception &)
    
    Franck Dary's avatar
    Franck Dary committed
            {
    
    Franck Dary's avatar
    Franck Dary committed
              util::myThrow(fmt::format("score '{}' is not a number in line '{}'", value, buffer));
    
    Franck Dary's avatar
    Franck Dary committed
            }
          }
    
    Franck Dary's avatar
    Franck Dary committed
        })){}
      }
    
      pclose(evalFromUD);
    }