From f5196710b77c94f47a88847b6f2ae92a1c65da64 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sat, 23 Mar 2019 18:06:31 +0100
Subject: [PATCH] Added an option to show dev gold accuracy

---
 maca_common/include/ProgramParameters.hpp |  1 +
 maca_common/src/ProgramParameters.cpp     |  1 +
 trainer/src/Trainer.cpp                   | 11 ++++++++++-
 trainer/src/macaon_train.cpp              |  2 ++
 4 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/maca_common/include/ProgramParameters.hpp b/maca_common/include/ProgramParameters.hpp
index 274d397..7534930 100644
--- a/maca_common/include/ProgramParameters.hpp
+++ b/maca_common/include/ProgramParameters.hpp
@@ -71,6 +71,7 @@ struct ProgramParameters
   static std::string tapeToMask;
   static float maskRate;
   static bool featureExtraction;
+  static bool devEvalOnGold;
 
   private :
 
diff --git a/maca_common/src/ProgramParameters.cpp b/maca_common/src/ProgramParameters.cpp
index ac15fa8..988120f 100644
--- a/maca_common/src/ProgramParameters.cpp
+++ b/maca_common/src/ProgramParameters.cpp
@@ -65,4 +65,5 @@ int ProgramParameters::dictCapacity;
 std::string ProgramParameters::tapeToMask;
 float ProgramParameters::maskRate;
 bool ProgramParameters::featureExtraction;
+bool ProgramParameters::devEvalOnGold;
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index cf0c844..f3076a4 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -79,7 +79,16 @@ void Trainer::computeScoreOnDev()
       if (pActionIsZeroCost)
         TI.addDevSuccess(tm.getCurrentClassifier()->name);
 
-      std::string actionName = pAction;
+      std::string actionName;
+      if (ProgramParameters::devEvalOnGold)
+      {
+        int neededActionIndex = tm.getCurrentClassifier()->getOracleActionIndex(*devConfig);
+        actionName = tm.getCurrentClassifier()->getActionName(neededActionIndex);
+      }
+      else
+      {
+        actionName = pAction;
+      }
       Action * action = tm.getCurrentClassifier()->getAction(actionName);
 
       if (ProgramParameters::debug)
diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp
index 7885ac4..e303685 100644
--- a/trainer/src/macaon_train.cpp
+++ b/trainer/src/macaon_train.cpp
@@ -87,6 +87,7 @@ po::options_description getOptionsDescription()
       "The rate of elements of the Tape that will be masked.")
     ("printTime", "Print time on stderr.")
     ("featureExtraction", "Use macaon only a feature extractor, print corpus to stdout.")
+    ("devEvalOnGold", "If true, dev accuracy will be computed on gold configurations.")
     ("shuffle", po::value<bool>()->default_value(true),
       "Shuffle examples after each iteration");
 
@@ -270,6 +271,7 @@ int main(int argc, char * argv[])
   ProgramParameters::printEntropy = vm.count("printEntropy") == 0 ? false : true;
   ProgramParameters::printTime = vm.count("printTime") == 0 ? false : true;
   ProgramParameters::featureExtraction = vm.count("featureExtraction") == 0 ? false : true;
+  ProgramParameters::devEvalOnGold = vm.count("devEvalOnGold") == 0 ? false : true;
   ProgramParameters::trainName = vm["train"].as<std::string>();
   ProgramParameters::devName = vm["dev"].as<std::string>();
   ProgramParameters::lang = vm["lang"].as<std::string>();
-- 
GitLab