diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 52fd11d4424a8f39c2b5a064faa916452d8218e2..9f5ab48dc1654a6faac0a37f0478b2a8dd103ae8 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -70,7 +70,8 @@ void Decoder::decode() for (auto & s : zeroCostActions) if (action->name == s) oAction = s; - errors.add({action->name, oAction, weightedActions}); + int actionCost = classifier->getActionCost(config, action->name); + errors.add({action->name, oAction, weightedActions, actionCost}); } action->apply(config); diff --git a/error_correction/include/Error.hpp b/error_correction/include/Error.hpp index 5f9f8d82dc5bec350a85f1e07f5d4ca2f873487e..da1259a173243f87ea8c28825eee044bf259004d 100644 --- a/error_correction/include/Error.hpp +++ b/error_correction/include/Error.hpp @@ -22,14 +22,16 @@ class Error int indexOfGold; int distanceWithGold; float entropy; + int cost; public : - Error(std::string &, std::string &, Classifier::WeightedActions &); + Error(std::string &, std::string &, Classifier::WeightedActions &, int cost); bool isError() const; const std::string & getType() const; bool goldWasAtDistance(int distance) const; float getEntropy() const; + float getCost() const; }; class ErrorSequence diff --git a/error_correction/src/Error.cpp b/error_correction/src/Error.cpp index 93fabe99260712dcb5b8705ee3b2b8da4dcb7e04..97a33d07f8c1be992a726613fc04d32a5c22d15a 100644 --- a/error_correction/src/Error.cpp +++ b/error_correction/src/Error.cpp @@ -1,7 +1,7 @@ #include "Error.hpp" -Error::Error(std::string & prediction, std::string & gold, Classifier::WeightedActions & weightedActions) : -prediction(prediction), gold(gold), weightedActions(weightedActions) +Error::Error(std::string & prediction, std::string & gold, Classifier::WeightedActions & weightedActions, int cost) : +prediction(prediction), gold(gold), weightedActions(weightedActions), cost(cost) { type = prediction + "->" + gold; if (ProgramParameters::onlyPrefixes) @@ -49,6 +49,11 @@ float Error::getEntropy() const return entropy; } +float Error::getCost() const +{ + return cost; +} + void ErrorSequence::add(const Error & error) { sequence.emplace_back(error); @@ -125,12 +130,13 @@ void Errors::printStats() { nbFirstErrorOccurencesByType[error.getType()]++; nbFirstErrorsTotal++; + nbFirstErrorIntroduced[error.getType()] += error.getCost(); for (unsigned int i = index+1; i < sequence.getSequence().size(); i++) if (sequence.getSequence()[i].isError()) { if ((int)(i - index) > window && window) break; - nbFirstErrorIntroduced[error.getType()] += 1; + nbFirstErrorIntroduced[error.getType()] += sequence.getSequence()[i].getCost(); } } for (unsigned int i = minDistanceToCheck; i <= maxDistanceToCheck; i++) diff --git a/error_correction/src/macaon_error_correction.cpp b/error_correction/src/macaon_error_correction.cpp index 4e40ab110282a4334706bb19dfa6726f31c0199d..1e45d9773114331c5dfbd655e7c06fb11c209ef5 100644 --- a/error_correction/src/macaon_error_correction.cpp +++ b/error_correction/src/macaon_error_correction.cpp @@ -217,7 +217,7 @@ int main(int argc, char * argv[]) if (configIsError) { - errors.add({action->name, zeroCostActions[0], weightedActions}); + errors.add({action->name, zeroCostActions[0], weightedActions, classifier->getActionCost(config, action->name)}); } } diff --git a/transition_machine/include/Classifier.hpp b/transition_machine/include/Classifier.hpp index b672702ebcac4b81fc39191ec8240f332d10c686..b7f7dfc7f7ef455a3876c0b0b0924ab655048d27 100644 --- a/transition_machine/include/Classifier.hpp +++ b/transition_machine/include/Classifier.hpp @@ -63,6 +63,13 @@ class Classifier public : + /// @brief Return how many errors will an action introduce. + /// + /// @param config The current config. + /// @param action The action to test. + /// + /// @return The number of errors that the action will introduce. + int getActionCost(Config & config, const std::string & action); /// @brief Print the weight that has been given to each Action by this Classifier. /// /// @param output Where to print. diff --git a/transition_machine/src/Classifier.cpp b/transition_machine/src/Classifier.cpp index 78f6f540a0112718ac159d180336dba7203dcb67..e422c7de7ac42284a6b9810bc3b7bc6113809c47 100644 --- a/transition_machine/src/Classifier.cpp +++ b/transition_machine/src/Classifier.cpp @@ -237,6 +237,11 @@ void Classifier::printTopology(FILE * output) mlp->printTopology(output); } +int Classifier::getActionCost(Config & config, const std::string & action) +{ + return oracle->getActionCost(config, action); +} + std::vector<std::string> Classifier::getZeroCostActions(Config & config) { std::vector<std::string> result;