Skip to content
Snippets Groups Projects
Commit 4d4c6ba3 authored by Franck Dary's avatar Franck Dary
Browse files

Added a program parameter for error analysis

parent c89e9660
No related branches found
No related tags found
No related merge requests found
......@@ -67,6 +67,17 @@ void Decoder::decode()
if (classifier->needsTrain() && ProgramParameters::errorAnalysis && (classifier->name == ProgramParameters::classifierName || ProgramParameters::classifierName.empty()))
{
auto zeroCostActions = classifier->getZeroCostActions(config);
if (zeroCostActions.empty())
{
fprintf(stderr, "ERROR (%s) : could not find zero cost action for classifier \'%s\'. Aborting.\n", ERRINFO, classifier->name.c_str());
config.printForDebug(stderr);
for (auto & a : weightedActions)
{
fprintf(stderr, "%s : ", a.second.second.c_str());
Oracle::explainCostOfAction(stderr, config, a.second.second);
}
exit(1);
}
std::string oAction = zeroCostActions[0];
for (auto & s : zeroCostActions)
if (action->name == s)
......
......@@ -57,6 +57,8 @@ po::options_description getOptionsDescription()
("errorAnalysis", "Print an analysis of errors")
("meanEntropy", "Print the mean entropy for error types")
("onlyPrefixes", "Only uses the prefixes of error categories")
("nbErrorsToShow", po::value<int>()->default_value(10),
"Display only the X most common errors")
("classifier", po::value<std::string>()->default_value(""),
"Name of the monitored classifier, if not specified monitor everyone");
......@@ -120,6 +122,7 @@ int main(int argc, char * argv[])
ProgramParameters::mcdName = vm["mcd"].as<std::string>();
ProgramParameters::debug = vm.count("debug") == 0 ? false : true;
ProgramParameters::errorAnalysis = vm.count("errorAnalysis") == 0 ? false : true;
ProgramParameters::nbErrorsToShow = vm["nbErrorsToShow"].as<int>();
ProgramParameters::meanEntropy = vm.count("meanEntropy") == 0 ? false : true;
ProgramParameters::onlyPrefixes = vm.count("onlyPrefixes") == 0 ? false : true;
ProgramParameters::dicts = vm["dicts"].as<std::string>();
......
......@@ -89,8 +89,8 @@ void Errors::printStats()
{
unsigned int minDistanceToCheck = 1;
unsigned int maxDistanceToCheck = 5;
int window = 10;
int nbErrorsToKeep = 10;
int window = 20;
int nbErrorsToKeep = ProgramParameters::nbErrorsToShow;
std::map<std::string, int> nbErrorOccurencesByType;
std::map<std::string, int> nbFirstErrorOccurencesByType;
std::map<std::string, float> nbFirstErrorIntroduced;
......
......@@ -59,6 +59,7 @@ struct ProgramParameters
static bool meanEntropy;
static bool onlyPrefixes;
static std::map<std::string,std::string> featureModelByClassifier;
static int nbErrorsToShow;
private :
......
......@@ -53,3 +53,4 @@ std::string ProgramParameters::classifierName;
int ProgramParameters::batchSize;
std::string ProgramParameters::loss;
std::map<std::string,std::string> ProgramParameters::featureModelByClassifier;
int ProgramParameters::nbErrorsToShow;
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment