From 5370f9719dbafc7303146dbbb7d5d4feb6a984b8 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 14 Feb 2020 11:01:29 +0100
Subject: [PATCH] During training, dev scores correspond to what column machine
 is predicting

---
 decoder/include/Decoder.hpp  | 19 +++++++++----
 decoder/src/Decoder.cpp      | 52 ++++++++++++++++++++++++++++++++----
 trainer/src/macaon_train.cpp | 20 ++++++++++----
 3 files changed, 76 insertions(+), 15 deletions(-)

diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp
index e8f7ecf..e576e0e 100644
--- a/decoder/include/Decoder.hpp
+++ b/decoder/include/Decoder.hpp
@@ -12,16 +12,25 @@ class Decoder
   ReadingMachine & machine;
   std::map<std::string, std::array<float,4>> evaluation;
 
+  private :
+
+  std::string getMetricOfColName(const std::string & colName) const;
+  std::vector<float> getScores(const std::set<std::string> & colNames, float (Decoder::* metric2score)(const std::string &) const) const;
+  float getMetricScore(const std::string & metric, std::size_t scoreIndex) const;
+  float getPrecision(const std::string & metric) const;
+  float getF1Score(const std::string & metric) const;
+  float getRecall(const std::string & metric) const;
+  float getAlignedAcc(const std::string & metric) const;
+
   public :
 
   Decoder(ReadingMachine & machine);
   void decode(BaseConfig & config, std::size_t beamSize, bool debug);
   void evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV);
-  float getMetricScore(const std::string & metric, std::size_t scoreIndex);
-  float getPrecision(const std::string & metric);
-  float getF1Score(const std::string & metric);
-  float getRecall(const std::string & metric);
-  float getAlignedAcc(const std::string & metric);
+  std::vector<float> getF1Scores(const std::set<std::string> & colNames) const;
+  std::vector<float> getAlignedAccs(const std::set<std::string> & colNames) const;
+  std::vector<float> getRecalls(const std::set<std::string> & colNames) const;
+  std::vector<float> getPrecisions(const std::set<std::string> & colNames) const;
 };
 
 #endif
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 78c3b86..08f4de8 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -50,7 +50,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
   } catch(std::exception & e) {util::myThrow(e.what());}
 }
 
-float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex)
+float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
 {
   auto found = evaluation.find(metric);
 
@@ -60,26 +60,68 @@ float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex
   return found->second[scoreIndex];
 }
 
-float Decoder::getPrecision(const std::string & metric)
+float Decoder::getPrecision(const std::string & metric) const
 {
   return getMetricScore(metric, 0);
 }
 
-float Decoder::getRecall(const std::string & metric)
+float Decoder::getRecall(const std::string & metric) const
 {
   return getMetricScore(metric, 1);
 }
 
-float Decoder::getF1Score(const std::string & metric)
+float Decoder::getF1Score(const std::string & metric) const
 {
   return getMetricScore(metric, 2);
 }
 
-float Decoder::getAlignedAcc(const std::string & metric)
+float Decoder::getAlignedAcc(const std::string & metric) const
 {
   return getMetricScore(metric, 3);
 }
 
+std::vector<float> Decoder::getF1Scores(const std::set<std::string> & colNames) const
+{
+  return getScores(colNames, &Decoder::getF1Score);
+}
+
+std::vector<float> Decoder::getAlignedAccs(const std::set<std::string> & colNames) const
+{
+  return getScores(colNames, &Decoder::getAlignedAcc);
+}
+
+std::vector<float> Decoder::getRecalls(const std::set<std::string> & colNames) const
+{
+  return getScores(colNames, &Decoder::getRecall);
+}
+
+std::vector<float> Decoder::getPrecisions(const std::set<std::string> & colNames) const
+{
+  return getScores(colNames, &Decoder::getPrecision);
+}
+
+std::vector<float> Decoder::getScores(const std::set<std::string> & colNames, float (Decoder::* metric2score)(const std::string &) const) const
+{
+  std::vector<float> scores;
+
+  for (auto & colName : colNames)
+    scores.push_back((this->*metric2score)(getMetricOfColName(colName)));
+
+  return scores; 
+}
+
+std::string Decoder::getMetricOfColName(const std::string & colName) const
+{
+  if (colName == "HEAD")
+    return "UAS";
+  if (colName == "DEPREL")
+    return "LAS";
+  if (colName == "EOS")
+    return "Sentences";
+
+  return colName;
+}
+
 void Decoder::evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV)
 {
   evaluation.clear();
diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp
index 62d09a6..fff80b3 100644
--- a/trainer/src/macaon_train.cpp
+++ b/trainer/src/macaon_train.cpp
@@ -96,17 +96,27 @@ int main(int argc, char * argv[])
       fmt::print(stderr, "\r{:80}\rDecoding dev...", " ");
     decoder.decode(devConfig, 1, debug);
     decoder.evaluate(devConfig, modelPath, devTsvFile);
-    float devScore = decoder.getF1Score("UPOS");
-    bool saved = devScore > bestDevScore;
+    std::vector<float> devScores = decoder.getF1Scores(machine.getPredicted());
+    std::string devScoresStr = "";
+    float devScoreMean = 0;
+    for (auto & score : devScores)
+    {
+      devScoresStr += fmt::format("{:5.2f}%,", score);
+      devScoreMean += score;
+    }
+    if (!devScoresStr.empty())
+      devScoresStr.pop_back();
+    devScoreMean /= devScores.size();
+    bool saved = devScoreMean > bestDevScore;
     if (saved)
     {
-      bestDevScore = devScore;
+      bestDevScore = devScoreMean;
       machine.save();
     }
     if (debug)
-      fmt::print(stderr, "Epoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : "");
+      fmt::print(stderr, "Epoch {:^9} loss = {:7.2f} dev = {} {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
     else
-      fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : "");
+      fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
   }
 
   return 0;
-- 
GitLab