From 1d29dbf6af0ca78c0f186868ead8a0f124e002b6 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@etu.univ-amu.fr>
Date: Fri, 21 Dec 2018 16:58:22 +0100
Subject: [PATCH] Added cost of errors to errorAnalysis

---
 decoder/src/Decoder.cpp                          |  3 ++-
 error_correction/include/Error.hpp               |  4 +++-
 error_correction/src/Error.cpp                   | 12 +++++++++---
 error_correction/src/macaon_error_correction.cpp |  2 +-
 transition_machine/include/Classifier.hpp        |  7 +++++++
 transition_machine/src/Classifier.cpp            |  5 +++++
 6 files changed, 27 insertions(+), 6 deletions(-)

diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 52fd11d..9f5ab48 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -70,7 +70,8 @@ void Decoder::decode()
       for (auto & s : zeroCostActions)
         if (action->name == s)
           oAction = s;
-      errors.add({action->name, oAction, weightedActions});
+      int actionCost = classifier->getActionCost(config, action->name);
+      errors.add({action->name, oAction, weightedActions, actionCost});
     }
 
     action->apply(config);
diff --git a/error_correction/include/Error.hpp b/error_correction/include/Error.hpp
index 5f9f8d8..da1259a 100644
--- a/error_correction/include/Error.hpp
+++ b/error_correction/include/Error.hpp
@@ -22,14 +22,16 @@ class Error
   int indexOfGold;
   int distanceWithGold;
   float entropy;
+  int cost;
   
   public :
 
-  Error(std::string &, std::string &, Classifier::WeightedActions &);
+  Error(std::string &, std::string &, Classifier::WeightedActions &, int cost);
   bool isError() const;
   const std::string & getType() const;
   bool goldWasAtDistance(int distance) const;
   float getEntropy() const;
+  float getCost() const;
 };
 
 class ErrorSequence
diff --git a/error_correction/src/Error.cpp b/error_correction/src/Error.cpp
index 93fabe9..97a33d0 100644
--- a/error_correction/src/Error.cpp
+++ b/error_correction/src/Error.cpp
@@ -1,7 +1,7 @@
 #include "Error.hpp"
 
-Error::Error(std::string & prediction, std::string & gold, Classifier::WeightedActions & weightedActions) :
-prediction(prediction), gold(gold), weightedActions(weightedActions)
+Error::Error(std::string & prediction, std::string & gold, Classifier::WeightedActions & weightedActions, int cost) :
+prediction(prediction), gold(gold), weightedActions(weightedActions), cost(cost)
 {
   type = prediction + "->" + gold;
   if (ProgramParameters::onlyPrefixes)
@@ -49,6 +49,11 @@ float Error::getEntropy() const
   return entropy;
 }
 
+float Error::getCost() const
+{
+  return cost;
+}
+
 void ErrorSequence::add(const Error & error)
 {
   sequence.emplace_back(error);
@@ -125,12 +130,13 @@ void Errors::printStats()
         {
           nbFirstErrorOccurencesByType[error.getType()]++;
           nbFirstErrorsTotal++;
+          nbFirstErrorIntroduced[error.getType()] += error.getCost();
           for (unsigned int i = index+1; i < sequence.getSequence().size(); i++)
             if (sequence.getSequence()[i].isError())
             {
               if ((int)(i - index) > window && window)
                 break;
-              nbFirstErrorIntroduced[error.getType()] += 1;
+              nbFirstErrorIntroduced[error.getType()] += sequence.getSequence()[i].getCost();
             }
         }
         for (unsigned int i = minDistanceToCheck; i <= maxDistanceToCheck; i++)
diff --git a/error_correction/src/macaon_error_correction.cpp b/error_correction/src/macaon_error_correction.cpp
index 4e40ab1..1e45d97 100644
--- a/error_correction/src/macaon_error_correction.cpp
+++ b/error_correction/src/macaon_error_correction.cpp
@@ -217,7 +217,7 @@ int main(int argc, char * argv[])
 
       if (configIsError)
       {
-        errors.add({action->name, zeroCostActions[0], weightedActions});
+        errors.add({action->name, zeroCostActions[0], weightedActions, classifier->getActionCost(config, action->name)});
       }
     }
 
diff --git a/transition_machine/include/Classifier.hpp b/transition_machine/include/Classifier.hpp
index b672702..b7f7dfc 100644
--- a/transition_machine/include/Classifier.hpp
+++ b/transition_machine/include/Classifier.hpp
@@ -63,6 +63,13 @@ class Classifier
 
   public :
 
+  /// @brief Return how many errors will an action introduce.
+  ///
+  /// @param config The current config.
+  /// @param action The action to test.
+  ///
+  /// @return The number of errors that the action will introduce.
+  int getActionCost(Config & config, const std::string & action);
   /// @brief Print the weight that has been given to each Action by this Classifier.
   ///
   /// @param output Where to print.
diff --git a/transition_machine/src/Classifier.cpp b/transition_machine/src/Classifier.cpp
index 78f6f54..e422c7d 100644
--- a/transition_machine/src/Classifier.cpp
+++ b/transition_machine/src/Classifier.cpp
@@ -237,6 +237,11 @@ void Classifier::printTopology(FILE * output)
   mlp->printTopology(output);
 }
 
+int Classifier::getActionCost(Config & config, const std::string & action)
+{
+  return oracle->getActionCost(config, action);
+}
+
 std::vector<std::string> Classifier::getZeroCostActions(Config & config)
 {
   std::vector<std::string> result;
-- 
GitLab