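// Implementation of the Decoder class: decoding a configuration with a ReadingMachine and reading back evaluation scores.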
#include "Decoder.hpp"
#include "SubConfig.hpp"
// Builds a decoder around the given reading machine by initializing the
// `machine` member from it. No other work is done at construction time.
Decoder::Decoder(ReadingMachine & machine)
  : machine(machine)
{}
// Decodes `config` with the reading machine: on each iteration the classifier's
// network scores the transitions, the selected transition is applied to the
// configuration, and the strategy's movement decides the next state and word
// index, until the strategy reports Strategy::endMovement.
//
// config           : configuration that is read and modified in place.
// beamSize         : not used by any statement visible in this excerpt.
// debug            : when true, prints the applied transition and movement to stderr.
// printAdvancement : when true, periodically prints the decoding speed to stderr.
//
// Errors are reported through util::myThrow.
//
// NOTE(review): this excerpt is missing lines (the function's opening brace, the
// declaration and update of `chosenTransition`, several closing braces); the
// comments below only describe the statements that are visible.
void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement)
// Inference only: disable gradient tracking and switch the network out of training mode.
torch::AutoGradMode useGrad(false);
machine.getClassifier()->getNN()->train(false);
// Hand the machine's predicted set over to the configuration.
config.addPredicted(machine.getPredicted());
// Progress reporting: the speed is printed once `printInterval` steps have been counted.
constexpr int printInterval = 50;
int nbExamplesProcessed = 0;
auto pastTime = std::chrono::high_resolution_clock::now();
// Start from the strategy's initial state.
config.setState(machine.getStrategy().getInitialState());
while (true)
{
// NOTE(review): `dictState` is not read by any visible statement — confirm whether it is still needed.
auto dictState = machine.getDict(config.getState()).getState();
// Extract the feature context for the current state, copy it into a 1-D int64 tensor
// (clone() so the tensor owns its data rather than aliasing `context`), move it to the
// network's device and run the network on it.
auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState()));
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
// NOTE(review): numeric_limits<float>::min() is the smallest positive normalized float,
// not the most negative value; lowest() is the usual seed for a running maximum.
float bestScore = std::numeric_limits<float>::min();
try
{
// Scan every score of the prediction, looking for an appliable transition.
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
// A transition is a candidate when nothing has been chosen yet or its score beats
// the best one so far, and it is appliable on the current configuration.
// (The body of this branch is not visible in this excerpt.)
if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config))
{
// Forward the message of any std::exception to util::myThrow.
} catch(std::exception & e) {util::myThrow(e.what());}
// Apply the chosen transition and record its name in the configuration's history.
auto * transition = machine.getTransitionSet().getTransition(chosenTransition);
transition->apply(config);
config.addToHistory(transition->getName());
// Optional progress display: once the step counter reaches printInterval, print the
// number of steps per second measured since the previous report.
if (printAdvancement)
if (++nbExamplesProcessed >= printInterval)
{
auto actualTime = std::chrono::high_resolution_clock::now();
// Elapsed time is measured in milliseconds and converted to seconds.
double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
pastTime = actualTime;
fmt::print(stderr, "\r{:80}\rdecoding... speed={:<6}ex/s\r", "", (int)(nbExamplesProcessed/seconds));
// Ask the strategy what to do after this transition: movement.first is passed to
// setState and movement.second to moveWordIndex below.
auto movement = machine.getStrategy().getMovement(config, transition->getName());
if (debug)
fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
// The strategy signals the end of decoding with a dedicated movement value.
if (movement == Strategy::endMovement)
break;
config.setState(movement.first);
// A refused word-index move is reported as an error.
if (!config.moveWordIndex(movement.second))
util::myThrow("Cannot move word index !");
}
// NOTE(review): the `try` matching this handler is not visible in this excerpt.
} catch(std::exception & e) {util::myThrow(e.what());}
// Force EOS when needed: if the transition set has an "EOS" transition and the value
// returned by getLastNotEmptyHypConst for the EOS column at the current word index is
// not Config::EOSSymbol1, push the word index on the stack and apply "EOS".
if (machine.getTransitionSet().getTransition("EOS") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1)
{
Action shift = Action::pushWordIndexOnStack();
shift.apply(config, shift);
machine.getTransitionSet().getTransition("EOS")->apply(config);
if (debug)
fmt::print(stderr, "Forcing EOS transition\n");
}
// Fill holes in important columns like "ID" and "HEAD" so that the output is compatible with the eval script.
try {config.addMissingColumns();}
catch (std::exception & e) {util::myThrow(e.what());}
float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
{
auto found = evaluation.find(metric);
if (found == evaluation.end())
util::myThrow(fmt::format("Cannot find metric '{}' {}\n", metric, evaluation.empty() ? "(call Decoder::evaluate() first)" : ""));
return found->second[scoreIndex];
}
// Single-metric accessors. Only their signatures are visible in this excerpt
// (the bodies are missing); each one takes a metric name and returns a float,
// and each is passed to getScores() as the scoring callback by the matching
// plural getter (getPrecisions, getRecalls, getF1Scores, getAlignedAccs).
// NOTE(review): presumably each forwards to getMetricScore() with a fixed score
// index — confirm against the full file.
float Decoder::getPrecision(const std::string & metric) const
float Decoder::getRecall(const std::string & metric) const
float Decoder::getF1Score(const std::string & metric) const
float Decoder::getAlignedAcc(const std::string & metric) const
std::vector<std::pair<float,std::string>> Decoder::getF1Scores(const std::set<std::string> & colNames) const
{
return getScores(colNames, &Decoder::getF1Score);
}
std::vector<std::pair<float,std::string>> Decoder::getAlignedAccs(const std::set<std::string> & colNames) const
{
return getScores(colNames, &Decoder::getAlignedAcc);
}
std::vector<std::pair<float,std::string>> Decoder::getRecalls(const std::set<std::string> & colNames) const
{
return getScores(colNames, &Decoder::getRecall);
}
std::vector<std::pair<float,std::string>> Decoder::getPrecisions(const std::set<std::string> & colNames) const
{
return getScores(colNames, &Decoder::getPrecision);
}
// Builds the list of (score, metric name) pairs for a set of columns.
//
// colNames     : column names to report on; the column named Config::idColName
//                is skipped.
// metric2score : member function that maps a metric name to its score
//                (e.g. &Decoder::getF1Score).
//
// Each remaining column name is translated to its metric name with
// getMetricOfColName(); that name is both passed to `metric2score` and stored
// next to the resulting score. Pairs are emitted in the iteration order of
// `colNames`.
std::vector<std::pair<float,std::string>> Decoder::getScores(const std::set<std::string> & colNames, float (Decoder::* metric2score)(const std::string &) const) const
{
  std::vector<std::pair<float, std::string>> scores;
  scores.reserve(colNames.size());

  for (const auto & colName : colNames)
  {
    if (colName == Config::idColName)
      continue;

    // Translate the column name once and reuse it for the lookup and the label.
    const std::string metric = getMetricOfColName(colName);
    const float score = (this->*metric2score)(metric);
    scores.emplace_back(score, metric);
  }

  return scores;
}
// Maps a column name to the name of the metric under which its score is stored.
// Columns with a dedicated metric name are translated through the table below;
// any other column name is returned unchanged.
std::string Decoder::getMetricOfColName(const std::string & colName) const
{
  // Column name -> metric name, for the columns whose metric is named differently.
  static constexpr std::pair<const char *, const char *> renamedColumns[] = {
    {"HEAD",   "UAS"},
    {"DEPREL", "LAS"},
    {"EOS",    "Sentences"},
    {"FEATS",  "UFeats"},
    {"FORM",   "Words"},
  };

  for (const auto & [column, metric] : renamedColumns)
    if (colName == column)
      return metric;

  return colName;
}
void Decoder::evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV)
{
evaluation.clear();
auto predictedTSV = (modelPath/"predicted_dev.tsv").string();
std::FILE * predictedTSVFile = std::fopen(predictedTSV.c_str(), "w");
config.print(predictedTSVFile);
std::fclose(predictedTSVFile);
std::FILE * evalFromUD = popen(fmt::format("{} {} {} -v", "../scripts/conll18_ud_eval.py", goldTSV, predictedTSV).c_str(), "r");
char buffer[1024];
while (!std::feof(evalFromUD))
{
if (buffer != std::fgets(buffer, 1024, evalFromUD))
break;
if (buffer[std::strlen(buffer)-1] == '\n')
buffer[std::strlen(buffer)-1] = '\0';
if (util::doIfNameMatch(std::regex("(.*)Metric(.*)"), buffer, [this, buffer](auto sm){}))
continue;
if (util::doIfNameMatch(std::regex("(.*)\\|(.*)\\|(.*)\\|(.*)\\|(.*)"), buffer, [this, buffer](auto sm)
auto metric = util::strip(sm[1]);
for (unsigned int i = 0; i < this->evaluation[metric].size(); i++)
auto value = util::strip(sm[i+2]);
if (value.empty())
try {this->evaluation[metric][i] = std::stof(value);}
util::myThrow(fmt::format("score '{}' is not a number in line '{}'", value, buffer));