Skip to content
Snippets Groups Projects
Commit 1d29dbf6 authored by Franck Dary's avatar Franck Dary
Browse files

Added cost of errors to errorAnalysis

parent 359dd3b5
Branches
No related tags found
No related merge requests found
...@@ -70,7 +70,8 @@ void Decoder::decode() ...@@ -70,7 +70,8 @@ void Decoder::decode()
for (auto & s : zeroCostActions) for (auto & s : zeroCostActions)
if (action->name == s) if (action->name == s)
oAction = s; oAction = s;
errors.add({action->name, oAction, weightedActions}); int actionCost = classifier->getActionCost(config, action->name);
errors.add({action->name, oAction, weightedActions, actionCost});
} }
action->apply(config); action->apply(config);
......
...@@ -22,14 +22,16 @@ class Error ...@@ -22,14 +22,16 @@ class Error
int indexOfGold; int indexOfGold;
int distanceWithGold; int distanceWithGold;
float entropy; float entropy;
int cost;
public : public :
Error(std::string &, std::string &, Classifier::WeightedActions &); Error(std::string &, std::string &, Classifier::WeightedActions &, int cost);
bool isError() const; bool isError() const;
const std::string & getType() const; const std::string & getType() const;
bool goldWasAtDistance(int distance) const; bool goldWasAtDistance(int distance) const;
float getEntropy() const; float getEntropy() const;
float getCost() const;
}; };
class ErrorSequence class ErrorSequence
......
#include "Error.hpp" #include "Error.hpp"
Error::Error(std::string & prediction, std::string & gold, Classifier::WeightedActions & weightedActions) : Error::Error(std::string & prediction, std::string & gold, Classifier::WeightedActions & weightedActions, int cost) :
prediction(prediction), gold(gold), weightedActions(weightedActions) prediction(prediction), gold(gold), weightedActions(weightedActions), cost(cost)
{ {
type = prediction + "->" + gold; type = prediction + "->" + gold;
if (ProgramParameters::onlyPrefixes) if (ProgramParameters::onlyPrefixes)
...@@ -49,6 +49,11 @@ float Error::getEntropy() const ...@@ -49,6 +49,11 @@ float Error::getEntropy() const
return entropy; return entropy;
} }
float Error::getCost() const
{
return cost;
}
void ErrorSequence::add(const Error & error) void ErrorSequence::add(const Error & error)
{ {
sequence.emplace_back(error); sequence.emplace_back(error);
...@@ -125,12 +130,13 @@ void Errors::printStats() ...@@ -125,12 +130,13 @@ void Errors::printStats()
{ {
nbFirstErrorOccurencesByType[error.getType()]++; nbFirstErrorOccurencesByType[error.getType()]++;
nbFirstErrorsTotal++; nbFirstErrorsTotal++;
nbFirstErrorIntroduced[error.getType()] += error.getCost();
for (unsigned int i = index+1; i < sequence.getSequence().size(); i++) for (unsigned int i = index+1; i < sequence.getSequence().size(); i++)
if (sequence.getSequence()[i].isError()) if (sequence.getSequence()[i].isError())
{ {
if ((int)(i - index) > window && window) if ((int)(i - index) > window && window)
break; break;
nbFirstErrorIntroduced[error.getType()] += 1; nbFirstErrorIntroduced[error.getType()] += sequence.getSequence()[i].getCost();
} }
} }
for (unsigned int i = minDistanceToCheck; i <= maxDistanceToCheck; i++) for (unsigned int i = minDistanceToCheck; i <= maxDistanceToCheck; i++)
......
...@@ -217,7 +217,7 @@ int main(int argc, char * argv[]) ...@@ -217,7 +217,7 @@ int main(int argc, char * argv[])
if (configIsError) if (configIsError)
{ {
errors.add({action->name, zeroCostActions[0], weightedActions}); errors.add({action->name, zeroCostActions[0], weightedActions, classifier->getActionCost(config, action->name)});
} }
} }
......
...@@ -63,6 +63,13 @@ class Classifier ...@@ -63,6 +63,13 @@ class Classifier
public : public :
/// @brief Return how many errors will an action introduce.
///
/// @param config The current config.
/// @param action The action to test.
///
/// @return The number of errors that the action will introduce.
int getActionCost(Config & config, const std::string & action);
/// @brief Print the weight that has been given to each Action by this Classifier. /// @brief Print the weight that has been given to each Action by this Classifier.
/// ///
/// @param output Where to print. /// @param output Where to print.
......
...@@ -237,6 +237,11 @@ void Classifier::printTopology(FILE * output) ...@@ -237,6 +237,11 @@ void Classifier::printTopology(FILE * output)
mlp->printTopology(output); mlp->printTopology(output);
} }
int Classifier::getActionCost(Config & config, const std::string & action)
{
return oracle->getActionCost(config, action);
}
std::vector<std::string> Classifier::getZeroCostActions(Config & config) std::vector<std::string> Classifier::getZeroCostActions(Config & config)
{ {
std::vector<std::string> result; std::vector<std::string> result;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment