diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp
index e8f7ecf17b88cefad3c49d48f0e86a014d3e7ef9..e576e0ed50183510b69e2ec693a50ac2e084a45c 100644
--- a/decoder/include/Decoder.hpp
+++ b/decoder/include/Decoder.hpp
@@ -12,16 +12,25 @@ class Decoder
   ReadingMachine & machine;
   std::map<std::string, std::array<float,4>> evaluation;
 
+  private :
+
+  std::string getMetricOfColName(const std::string & colName) const;
+  std::vector<float> getScores(const std::set<std::string> & colNames, float (Decoder::* metric2score)(const std::string &) const) const;
+  float getMetricScore(const std::string & metric, std::size_t scoreIndex) const;
+  float getPrecision(const std::string & metric) const;
+  float getF1Score(const std::string & metric) const;
+  float getRecall(const std::string & metric) const;
+  float getAlignedAcc(const std::string & metric) const;
+
   public :
 
   Decoder(ReadingMachine & machine);
   void decode(BaseConfig & config, std::size_t beamSize, bool debug);
   void evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV);
-  float getMetricScore(const std::string & metric, std::size_t scoreIndex);
-  float getPrecision(const std::string & metric);
-  float getF1Score(const std::string & metric);
-  float getRecall(const std::string & metric);
-  float getAlignedAcc(const std::string & metric);
+  std::vector<float> getF1Scores(const std::set<std::string> & colNames) const;
+  std::vector<float> getAlignedAccs(const std::set<std::string> & colNames) const;
+  std::vector<float> getRecalls(const std::set<std::string> & colNames) const;
+  std::vector<float> getPrecisions(const std::set<std::string> & colNames) const;
 };
 
 #endif
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 78c3b86103aa72aa646e33af363f1ffd07d2c71d..08f4de8d984b3ea6d258b3d3756829de275a29ff 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -50,7 +50,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
   } catch(std::exception & e) {util::myThrow(e.what());}
 }
 
-float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex)
+float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
 {
   auto found = evaluation.find(metric);
 
@@ -60,26 +60,77 @@ float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex
   return found->second[scoreIndex];
 }
 
-float Decoder::getPrecision(const std::string & metric)
+float Decoder::getPrecision(const std::string & metric) const
 {
   return getMetricScore(metric, 0);
 }
 
-float Decoder::getRecall(const std::string & metric)
+float Decoder::getRecall(const std::string & metric) const
 {
   return getMetricScore(metric, 1);
 }
 
-float Decoder::getF1Score(const std::string & metric)
+float Decoder::getF1Score(const std::string & metric) const
 {
   return getMetricScore(metric, 2);
 }
 
-float Decoder::getAlignedAcc(const std::string & metric)
+float Decoder::getAlignedAcc(const std::string & metric) const
 {
   return getMetricScore(metric, 3);
 }
 
+// F1 score (evaluation slot 2) for each column of colNames.
+std::vector<float> Decoder::getF1Scores(const std::set<std::string> & colNames) const
+{
+  return getScores(colNames, &Decoder::getF1Score);
+}
+
+// Aligned accuracy (evaluation slot 3) for each column of colNames.
+std::vector<float> Decoder::getAlignedAccs(const std::set<std::string> & colNames) const
+{
+  return getScores(colNames, &Decoder::getAlignedAcc);
+}
+
+// Recall (evaluation slot 1) for each column of colNames.
+std::vector<float> Decoder::getRecalls(const std::set<std::string> & colNames) const
+{
+  return getScores(colNames, &Decoder::getRecall);
+}
+
+// Precision (evaluation slot 0) for each column of colNames.
+std::vector<float> Decoder::getPrecisions(const std::set<std::string> & colNames) const
+{
+  return getScores(colNames, &Decoder::getPrecision);
+}
+
+// Applies metric2score to the metric associated with each column name and
+// returns the scores in the iteration order of colNames.
+std::vector<float> Decoder::getScores(const std::set<std::string> & colNames, float (Decoder::* metric2score)(const std::string &) const) const
+{
+  std::vector<float> scores;
+  scores.reserve(colNames.size());
+
+  for (auto & colName : colNames)
+    scores.push_back((this->*metric2score)(getMetricOfColName(colName)));
+
+  return scores;
+}
+
+// Maps a column name to the name of the metric used to score it; names without
+// a dedicated mapping are used as metric names unchanged.
+std::string Decoder::getMetricOfColName(const std::string & colName) const
+{
+  if (colName == "HEAD")
+    return "UAS";
+  if (colName == "DEPREL")
+    return "LAS";
+  if (colName == "EOS")
+    return "Sentences";
+
+  return colName;
+}
+
 void Decoder::evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV)
 {
   evaluation.clear();
diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp
index 62d09a61cfadedb6b0e91ae27daf38c102025945..fff80b3f7b39ac90805a92844bcf01d74aeacf55 100644
--- a/trainer/src/macaon_train.cpp
+++ b/trainer/src/macaon_train.cpp
@@ -96,17 +96,31 @@ int main(int argc, char * argv[])
     fmt::print(stderr, "\r{:80}\rDecoding dev...", " ");
     decoder.decode(devConfig, 1, debug);
     decoder.evaluate(devConfig, modelPath, devTsvFile);
-    float devScore = decoder.getF1Score("UPOS");
-    bool saved = devScore > bestDevScore;
+    // Model selection uses the mean dev F1 over the columns returned by machine.getPredicted().
+    std::vector<float> devScores = decoder.getF1Scores(machine.getPredicted());
+    std::string devScoresStr = "";
+    float devScoreMean = 0;
+    for (auto & score : devScores)
+    {
+      devScoresStr += fmt::format("{:5.2f}%,", score);
+      devScoreMean += score;
+    }
+    // Drop the trailing comma left by the loop.
+    if (!devScoresStr.empty())
+      devScoresStr.pop_back();
+    // With no scores the mean stays 0 instead of becoming 0/0 = NaN.
+    if (!devScores.empty())
+      devScoreMean /= devScores.size();
+    bool saved = devScoreMean > bestDevScore;
     if (saved)
     {
-      bestDevScore = devScore;
+      bestDevScore = devScoreMean;
       machine.save();
     }
     if (debug)
-      fmt::print(stderr, "Epoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : "");
+      fmt::print(stderr, "Epoch {:^9} loss = {:7.2f} dev = {} {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
     else
-      fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : "");
+      fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
   }
 
   return 0;