From 5370f9719dbafc7303146dbbb7d5d4feb6a984b8 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Fri, 14 Feb 2020 11:01:29 +0100 Subject: [PATCH] During training, dev scores correspond to what column machine is predictiong --- decoder/include/Decoder.hpp | 19 +++++++++---- decoder/src/Decoder.cpp | 52 ++++++++++++++++++++++++++++++++---- trainer/src/macaon_train.cpp | 20 ++++++++++---- 3 files changed, 76 insertions(+), 15 deletions(-) diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp index e8f7ecf..e576e0e 100644 --- a/decoder/include/Decoder.hpp +++ b/decoder/include/Decoder.hpp @@ -12,16 +12,25 @@ class Decoder ReadingMachine & machine; std::map<std::string, std::array<float,4>> evaluation; + private : + + std::string getMetricOfColName(const std::string & colName) const; + std::vector<float> getScores(const std::set<std::string> & colNames, float (Decoder::* metric2score)(const std::string &) const) const; + float getMetricScore(const std::string & metric, std::size_t scoreIndex) const; + float getPrecision(const std::string & metric) const; + float getF1Score(const std::string & metric) const; + float getRecall(const std::string & metric) const; + float getAlignedAcc(const std::string & metric) const; + public : Decoder(ReadingMachine & machine); void decode(BaseConfig & config, std::size_t beamSize, bool debug); void evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV); - float getMetricScore(const std::string & metric, std::size_t scoreIndex); - float getPrecision(const std::string & metric); - float getF1Score(const std::string & metric); - float getRecall(const std::string & metric); - float getAlignedAcc(const std::string & metric); + std::vector<float> getF1Scores(const std::set<std::string> & colNames) const; + std::vector<float> getAlignedAccs(const std::set<std::string> & colNames) const; + std::vector<float> getRecalls(const std::set<std::string> & colNames) const; + std::vector<float> getPrecisions(const std::set<std::string> & colNames) const; }; #endif diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 78c3b86..08f4de8 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -50,7 +50,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug) } catch(std::exception & e) {util::myThrow(e.what());} } -float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) +float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const { auto found = evaluation.find(metric); @@ -60,26 +60,68 @@ float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex return found->second[scoreIndex]; } -float Decoder::getPrecision(const std::string & metric) +float Decoder::getPrecision(const std::string & metric) const { return getMetricScore(metric, 0); } -float Decoder::getRecall(const std::string & metric) +float Decoder::getRecall(const std::string & metric) const { return getMetricScore(metric, 1); } -float Decoder::getF1Score(const std::string & metric) +float Decoder::getF1Score(const std::string & metric) const { return getMetricScore(metric, 2); } -float Decoder::getAlignedAcc(const std::string & metric) +float Decoder::getAlignedAcc(const std::string & metric) const { return getMetricScore(metric, 3); } +std::vector<float> Decoder::getF1Scores(const std::set<std::string> & colNames) const +{ + return getScores(colNames, &Decoder::getF1Score); +} + +std::vector<float> Decoder::getAlignedAccs(const std::set<std::string> & colNames) const +{ + return getScores(colNames, &Decoder::getAlignedAcc); +} + +std::vector<float> Decoder::getRecalls(const std::set<std::string> & colNames) const +{ + return getScores(colNames, &Decoder::getRecall); +} + +std::vector<float> Decoder::getPrecisions(const std::set<std::string> & colNames) const +{ + return getScores(colNames, &Decoder::getPrecision); +} + +std::vector<float> Decoder::getScores(const std::set<std::string> & colNames, float (Decoder::* metric2score)(const std::string &) const) const +{ + std::vector<float> scores; + + for (auto & colName : colNames) + scores.push_back((this->*metric2score)(getMetricOfColName(colName))); + + return scores; +} + +std::string Decoder::getMetricOfColName(const std::string & colName) const +{ + if (colName == "HEAD") + return "UAS"; + if (colName == "DEPREL") + return "LAS"; + if (colName == "EOS") + return "Sentences"; + + return colName; +} + void Decoder::evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV) { evaluation.clear(); diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp index 62d09a6..fff80b3 100644 --- a/trainer/src/macaon_train.cpp +++ b/trainer/src/macaon_train.cpp @@ -96,17 +96,27 @@ int main(int argc, char * argv[]) fmt::print(stderr, "\r{:80}\rDecoding dev...", " "); decoder.decode(devConfig, 1, debug); decoder.evaluate(devConfig, modelPath, devTsvFile); - float devScore = decoder.getF1Score("UPOS"); - bool saved = devScore > bestDevScore; + std::vector<float> devScores = decoder.getF1Scores(machine.getPredicted()); + std::string devScoresStr = ""; + float devScoreMean = 0; + for (auto & score : devScores) + { + devScoresStr += fmt::format("{:5.2f}%,", score); + devScoreMean += score; + } + if (!devScoresStr.empty()) + devScoresStr.pop_back(); + devScoreMean /= devScores.size(); + bool saved = devScoreMean > bestDevScore; if (saved) { - bestDevScore = devScore; + bestDevScore = devScoreMean; machine.save(); } if (debug) - fmt::print(stderr, "Epoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : ""); + fmt::print(stderr, "Epoch {:^9} loss = {:7.2f} dev = {} {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : ""); else - fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : ""); + fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : ""); } return 0; -- GitLab