From 4818cfc87227320e31577af565bfe5eca7cccb92 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 12 Jan 2022 14:13:38 +0100
Subject: [PATCH] Giving MCD to the eval script

---
 decoder/include/Decoder.hpp | 2 +-
 decoder/src/Decoder.cpp     | 4 ++--
 trainer/src/MacaonTrain.cpp | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp
index fe8c870..8604e18 100644
--- a/decoder/include/Decoder.hpp
+++ b/decoder/include/Decoder.hpp
@@ -26,7 +26,7 @@ class Decoder
 
   Decoder(ReadingMachine & machine);
   std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement);
-  void evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted);
+  void evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted, std::string mcd);
   std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const;
   std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const;
   std::vector<std::pair<float,std::string>> getRecalls(const std::set<std::string> & colNames) const;
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 5394280..076d718 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -135,7 +135,7 @@ std::string Decoder::getMetricOfColName(const std::string & colName) const
   return colName;
 }
 
-void Decoder::evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted)
+void Decoder::evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted, std::string mcd)
 {
   evaluation.clear();
   auto predictedTSV = (modelPath/"predicted_dev.tsv").string();
@@ -144,7 +144,7 @@ void Decoder::evaluate(const std::vector<const Config *> & configs, std::filesys
     configs[i]->print(predictedTSVFile, i==0);
   std::fclose(predictedTSVFile);
 
-  std::FILE * evalFromUD = popen(fmt::format("{} {} {} -x {}", "../scripts/conll18_ud_eval.py", goldTSV, predictedTSV, util::join(",", std::vector<std::string>(predicted.begin(), predicted.end()))).c_str(), "r");
+  std::FILE * evalFromUD = popen(fmt::format("{} {} {} -x {} --mcd {}", "../scripts/conll18_ud_eval.py", goldTSV, predictedTSV, util::join(",", std::vector<std::string>(predicted.begin(), predicted.end())), mcd).c_str(), "r");
 
   char buffer[1024];
   while (!std::feof(evalFromUD))
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index cef2bea..6efa3c6 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -345,7 +345,7 @@ int MacaonTrain::main()
       std::vector<const Config *> devConfigsPtrs;
       for (auto & devConfig : devConfigs)
         devConfigsPtrs.emplace_back(&devConfig);
-      decoder.evaluate(devConfigsPtrs, modelPath, devTsvFile, machine.getPredicted());
+      decoder.evaluate(devConfigsPtrs, modelPath, devTsvFile, machine.getPredicted(), mcd);
       devScores = decoder.getF1Scores(machine.getPredicted());
     }
     else
-- 
GitLab