Skip to content
Snippets Groups Projects
Commit 1d29dbf6 authored by Franck Dary's avatar Franck Dary
Browse files

Added cost of errors to errorAnalysis

parent 359dd3b5
Branches
No related tags found
No related merge requests found
......@@ -70,7 +70,8 @@ void Decoder::decode()
for (auto & s : zeroCostActions)
if (action->name == s)
oAction = s;
errors.add({action->name, oAction, weightedActions});
int actionCost = classifier->getActionCost(config, action->name);
errors.add({action->name, oAction, weightedActions, actionCost});
}
action->apply(config);
......
......@@ -22,14 +22,16 @@ class Error
int indexOfGold;
int distanceWithGold;
float entropy;
int cost;
public :
Error(std::string &, std::string &, Classifier::WeightedActions &);
Error(std::string &, std::string &, Classifier::WeightedActions &, int cost);
bool isError() const;
const std::string & getType() const;
bool goldWasAtDistance(int distance) const;
float getEntropy() const;
float getCost() const;
};
class ErrorSequence
......
#include "Error.hpp"
Error::Error(std::string & prediction, std::string & gold, Classifier::WeightedActions & weightedActions) :
prediction(prediction), gold(gold), weightedActions(weightedActions)
Error::Error(std::string & prediction, std::string & gold, Classifier::WeightedActions & weightedActions, int cost) :
prediction(prediction), gold(gold), weightedActions(weightedActions), cost(cost)
{
type = prediction + "->" + gold;
if (ProgramParameters::onlyPrefixes)
......@@ -49,6 +49,11 @@ float Error::getEntropy() const
return entropy;
}
float Error::getCost() const
{
return cost;
}
void ErrorSequence::add(const Error & error)
{
sequence.emplace_back(error);
......@@ -125,12 +130,13 @@ void Errors::printStats()
{
nbFirstErrorOccurencesByType[error.getType()]++;
nbFirstErrorsTotal++;
nbFirstErrorIntroduced[error.getType()] += error.getCost();
for (unsigned int i = index+1; i < sequence.getSequence().size(); i++)
if (sequence.getSequence()[i].isError())
{
if ((int)(i - index) > window && window)
break;
nbFirstErrorIntroduced[error.getType()] += 1;
nbFirstErrorIntroduced[error.getType()] += sequence.getSequence()[i].getCost();
}
}
for (unsigned int i = minDistanceToCheck; i <= maxDistanceToCheck; i++)
......
......@@ -217,7 +217,7 @@ int main(int argc, char * argv[])
if (configIsError)
{
errors.add({action->name, zeroCostActions[0], weightedActions});
errors.add({action->name, zeroCostActions[0], weightedActions, classifier->getActionCost(config, action->name)});
}
}
......
......@@ -63,6 +63,13 @@ class Classifier
public :
/// @brief Return how many errors will an action introduce.
///
/// @param config The current config.
/// @param action The action to test.
///
/// @return The number of errors that the action will introduce.
int getActionCost(Config & config, const std::string & action);
/// @brief Print the weight that has been given to each Action by this Classifier.
///
/// @param output Where to print.
......
......@@ -237,6 +237,11 @@ void Classifier::printTopology(FILE * output)
mlp->printTopology(output);
}
int Classifier::getActionCost(Config & config, const std::string & action)
{
return oracle->getActionCost(config, action);
}
std::vector<std::string> Classifier::getZeroCostActions(Config & config)
{
std::vector<std::string> result;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment