// Classifier.cpp : implementation of the Classifier class.
#include "Classifier.hpp"
#include "File.hpp"
#include "util.hpp"
/// Builds a classifier from its description file.
///
/// The file is opened at ProgramParameters::expPath + filename and is parsed
/// line by line with fscanf. It must start with "Name : ", "Type : " and
/// "Oracle : " entries. What follows depends on the type:
///  - Information : an "Oracle Filename : " entry is read and passed to
///    Oracle::getOracle together with the oracle name.
///  - any other non-Prediction type : the oracle is obtained from its name only.
///  - Prediction : "Feature Model : ", "Action Set : " and "Topology : "
///    entries are read, in that order.
/// Any entry that cannot be parsed prints an error and terminates the program.
///
/// @param filename  Name of the description file, relative to expPath.
/// @param trainMode Stored as-is in this->trainMode.
Classifier::Classifier(const std::string & filename, bool trainMode)
{
  this->trainMode = trainMode;

  // Reports a malformed description file and terminates the program.
  auto badFormatAndAbort = [&filename](const char * errInfo)
  {
    fprintf(stderr, "ERROR (%s) : file %s bad format. Aborting.\n", errInfo, filename.c_str());
    exit(1);
  };

  File file(ProgramParameters::expPath + filename, "r");
  FILE * fd = file.getDescriptor();

  // Every %s conversion below carries a field width of 1023 so that a token
  // longer than the buffer can never write past its 1024 bytes (1023 + '\0').
  char buffer[1024];

  if(fscanf(fd, "Name : %1023s\n", buffer) != 1)
    badFormatAndAbort(ERRINFO);

  name = buffer;

  if(fscanf(fd, "Type : %1023s\n", buffer) != 1)
    badFormatAndAbort(ERRINFO);

  type = str2type(buffer);

  // From here on, buffer holds the oracle name until the oracle is created.
  if(fscanf(fd, "Oracle : %1023s\n", buffer) != 1)
    badFormatAndAbort(ERRINFO);

  if(type != Type::Prediction)
  {
    char buffer2[1024];

    if(type == Type::Information)
    {
      if(fscanf(fd, "Oracle Filename : %1023s\n", buffer2) != 1)
        badFormatAndAbort(ERRINFO);

      oracle = Oracle::getOracle(buffer, ProgramParameters::expPath + std::string("/") + buffer2);
    }
    else
      oracle = Oracle::getOracle(buffer);

    // Non-Prediction classifiers do not read an action set file name; the
    // ActionSet is built from a name derived from the classifier's own name.
    as.reset(new ActionSet(this->name + "_ActionSet", true));

    return;
  }

  oracle = Oracle::getOracle(buffer);

  if(fscanf(fd, "Feature Model : %1023s\n", buffer) != 1)
    badFormatAndAbort(ERRINFO);

  // An entry in featureModelByClassifier for this classifier takes precedence
  // over the feature model named in the description file.
  std::string fmFilename = ProgramParameters::expPath + buffer;
  if (ProgramParameters::featureModelByClassifier.count(this->name))
    fmFilename = ProgramParameters::featureModelByClassifier[this->name];

  fm.reset(new FeatureModel(fmFilename));

  if(fscanf(fd, "Action Set : %1023s\n", buffer) != 1)
    badFormatAndAbort(ERRINFO);

  as.reset(new ActionSet(ProgramParameters::expPath + buffer, false));

  if(fscanf(fd, "Topology : %1023s\n", buffer) != 1)
    badFormatAndAbort(ERRINFO);

  topology = buffer;
}
/// Converts the textual type found in a classifier description into a Type.
///
/// Recognized values are "Prediction", "Information" and "Forced". Any other
/// string prints an error on stderr and terminates the program.
///
/// @param s Textual type name.
/// @return The matching Type value.
Classifier::Type Classifier::str2type(const std::string & s)
{
  const struct
  {
    const char * label;
    Type value;
  } knownTypes[] =
  {
    {"Prediction", Type::Prediction},
    {"Information", Type::Information},
    {"Forced", Type::Forced}
  };

  for (const auto & known : knownTypes)
    if (s == known.label)
      return known.value;

  fprintf(stderr, "ERROR (%s) : invalid type \'%s\'. Aborting.\n", ERRINFO, s.c_str());
  exit(1);

  // Never reached : exit() does not return. Kept so every path returns a value.
  return Type::Prediction;
}
/// Computes the weighted list of actions for a configuration.
///
/// For a Prediction classifier, the classifier is initialized if needed, the
/// feature description of config is scored by the MLP, and one entry per score
/// is produced : (is the action appliable, (score, action name)). Entries are
/// then sorted by decreasing score.
/// For any other type, the result holds a single entry flagged as appliable,
/// with weight 1.0 and the action given by the oracle.
///
/// @param config Configuration to evaluate.
/// @return The weighted actions.
Classifier::WeightedActions Classifier::weightActions(Config & config)
{
  using Entry = std::pair< bool, std::pair<float, std::string> >;

  WeightedActions weighted;

  if(type != Type::Prediction)
  {
    weighted.emplace_back(true, std::pair<float, std::string>(1.0, getOracleAction(config)));
    return weighted;
  }

  initClassifier(config);

  auto description = fm->getFeatureDescription(config);
  auto scores = mlp->predict(description);

  if (ProgramParameters::showFeatureRepresentation == 1)
    description.printForDebug(stderr);

  // Score number k is paired with action number k of the action set.
  for (unsigned int k = 0; k < scores.size(); k++)
  {
    Action & action = as->actions[k];
    weighted.emplace_back(action.appliable(config), std::pair<float, std::string>(scores[k], action.name));
  }

  auto higherScoreFirst = [](const Entry & a, const Entry & b)
  {
    return a.second.first > b.second.first;
  };

  std::sort(weighted.begin(), weighted.end(), higherScoreFirst);

  return weighted;
}
/// Creates the MLP of this classifier on first use.
///
/// Does nothing when the classifier is not of type Prediction or when the MLP
/// already exists. Otherwise, if the file expPath + name + ".model" exists the
/// MLP is built from it; if not, a fresh MLP is created and initialized with
/// an input size equal to the sum of the dimensions of every dict of every
/// feature value of config's feature description, the stored topology, and
/// one output per action of the action set.
///
/// @param config Configuration used to measure the input size of a new MLP.
void Classifier::initClassifier(Config & config)
{
  if(type != Type::Prediction || mlp.get())
    return;

  std::string modelFilename = ProgramParameters::expPath + name + ".model";

  if (fileExists(modelFilename))
  {
    mlp.reset(new MLP(modelFilename));
    Dict::initDicts(mlp->getModel(), name);
    return;
  }

  mlp.reset(new MLP());
  Dict::initDicts(mlp->getModel(), name);

  auto fd = fm->getFeatureDescription(config);

  int nbInputs = 0;
  int nbOutputs = static_cast<int>(as->actions.size());

  // Iterate by reference : the values are only read, copying each feature
  // value (and its list of dicts) on every iteration would be wasted work.
  for (const auto & feat : fd.values)
    for (const auto & dict : feat.dicts)
      nbInputs += dict->getDimension();

  mlp->init(nbInputs, topology, nbOutputs);
}
/// Returns the feature description of config according to the feature model.
///
/// Only meaningful for Prediction classifiers : for any other type an error is
/// printed on stderr and the program terminates.
///
/// @param config Configuration to describe.
/// @return The feature description computed by the feature model.
FeatureModel::FeatureDescription Classifier::getFeatureDescription(Config & config)
{
  const bool isPrediction = (type == Type::Prediction);

  if(!isPrediction)
  {
    fprintf(stderr, "ERROR (%s) : classifier \'%s\' has no feature description. Aborting.\n", ERRINFO, name.c_str());
    exit(1);
  }

  return fm->getFeatureDescription(config);
}
/// Asks the oracle which action it gives for config.
///
/// @param config Configuration handed to the oracle.
/// @return The action returned by the oracle.
std::string Classifier::getOracleAction(Config & config)
{
  std::string oracleAction = oracle->getAction(config);

  return oracleAction;
}
/// Returns the index, in the action set, of the action the oracle gives for config.
///
/// @param config Configuration handed to the oracle.
/// @return The index reported by the action set for the oracle's action.
int Classifier::getOracleActionIndex(Config & config)
{
  std::string oracleAction = oracle->getAction(config);

  return as->getActionIndex(oracleAction);
}
/// Looks up an action in the action set of this classifier.
///
/// @param action Action to look up.
/// @return The index reported by the action set.
int Classifier::getActionIndex(const std::string & action)
{
  const int index = as->getActionIndex(action);

  return index;
}
/// Looks up an action name in the action set of this classifier.
///
/// @param actionIndex Index handed to the action set.
/// @return The name reported by the action set.
std::string Classifier::getActionName(int actionIndex)
{
  std::string actionName = as->getActionName(actionIndex);

  return actionName;
}
/// Prints weighted actions between two separator lines of 80 '-' characters.
///
/// The first threshhold entries of wa are printed one per line, entries whose
/// flag is set being marked with '*'. If none of the printed entries had its
/// flag set, the first flagged entry located at index threshhold or later is
/// printed as well.
///
/// @param output     Stream to print to.
/// @param wa         Weighted actions, printed in their current order.
/// @param threshhold Maximum number of leading entries to print.
void Classifier::printWeightedActions(FILE * output, WeightedActions & wa, int threshhold)
{
  // Prints the separator line.
  auto printSeparator = [output]()
  {
    const int nbCols = 80;
    const char symbol = '-';

    for(int col = 0; col < nbCols; col++)
      fprintf(output, "%c%s", symbol, col == nbCols-1 ? "\n" : "");
  };

  printSeparator();

  bool anyPossiblePrinted = false;

  for(unsigned int k = 0; k < wa.size() && (int)k < threshhold; k++)
  {
    auto & entry = wa[k];
    const bool possible = entry.first;

    if(possible)
      anyPossiblePrinted = true;

    fprintf(output, "%s %6.2f %s\n", possible ? "*" : " ", entry.second.first, entry.second.second.c_str());
  }

  if(!anyPossiblePrinted)
  {
    // threshhold is converted to unsigned here, exactly as the start index has
    // always been computed.
    for(unsigned int k = threshhold; k < wa.size(); k++)
    {
      if(!wa[k].first)
        continue;

      fprintf(output, "%s %6.2f %s\n", "*", wa[k].second.first, wa[k].second.second.c_str());
      break;
    }
  }

  printSeparator();
}
/// Saves the MLP of this classifier.
///
/// Only Prediction classifiers can be saved : for any other type an error is
/// printed on stderr and the program terminates.
///
/// @param filename Passed unchanged to the MLP's save function.
void Classifier::save(const std::string & filename)
{
  const bool isPrediction = (type == Type::Prediction);

  if(!isPrediction)
  {
    fprintf(stderr, "ERROR (%s) : classifier \'%s\' cannot be saved. Aborting.\n", ERRINFO, name.c_str());
    exit(1);
  }

  mlp->save(filename);
}
/// Fetches an action from the action set of this classifier.
///
/// @param name Name handed to the action set.
/// @return The pointer returned by the action set.
Action * Classifier::getAction(const std::string & name)
{
  Action * action = as->getAction(name);

  return action;
}
/// Tells whether this classifier is of type Prediction.
///
/// @return true for a Prediction classifier, false for any other type.
bool Classifier::needsTrain()
{
  const bool isPrediction = (type == Type::Prediction);

  return isPrediction;
}
/// Prints "<name> topology : " followed by whatever the MLP prints for its topology.
///
/// The MLP is used without any check, so it must already exist when this is called.
///
/// @param output Stream to print to.
void Classifier::printTopology(FILE * output)
{
  const char * classifierName = name.c_str();

  fprintf(output, "%s topology : ", classifierName);
  mlp->printTopology(output);
}
/// Asks the oracle for the cost of an action in a configuration.
///
/// @param config Configuration handed to the oracle.
/// @param action Action handed to the oracle.
/// @return The cost returned by the oracle.
int Classifier::getActionCost(Config & config, const std::string & action)
{
  const int cost = oracle->getActionCost(config, action);

  return cost;
}
/// Lists the names of the actions that are appliable in config and for which
/// the oracle reports a cost of 0.
///
/// If no such action exists and the action set has a default action, the name
/// of the action at index 0 is returned instead.
///
/// @param config Configuration to evaluate.
/// @return The selected action names, in action set order (possibly empty).
std::vector<std::string> Classifier::getZeroCostActions(Config & config)
{
  std::vector<std::string> zeroCostNames;

  for (Action & action : as->actions)
  {
    if (!action.appliable(config))
      continue;

    if (oracle->getActionCost(config, action.name) != 0)
      continue;

    zeroCostNames.emplace_back(action.name);
  }

  if (zeroCostNames.empty() && as->hasDefaultAction)
    zeroCostNames.emplace_back(as->getActionName(0));

  return zeroCostNames;
}
/// Updates the MLP with one example.
///
/// The feature description of config is computed by the feature model and
/// passed to the MLP's update function together with gold.
///
/// @param config Configuration the example is built from.
/// @param gold   Value passed as second argument to the MLP update.
/// @return The value returned by the MLP update.
float Classifier::trainOnExample(Config & config, int gold)
{
  auto description = fm->getFeatureDescription(config);

  const float updateResult = mlp->update(description, gold);

  return updateResult;
}
void Classifier::explainCostOfActions(FILE * output, Config & config)
{
for (Action & a : as->actions)
{
fprintf(output, "%s : ", a.name.c_str());
if (!a.appliable(config))
{
fprintf(output, "not appliable\n");
continue;
}
oracle->explainCostOfAction(output, config, a.name);
}
}
/// Computes a value from the first two entries of wa (or fewer if wa is shorter).
///
/// For each considered entry with score s, and with s0 the score of wa[0],
/// the quantity s - (s - s0) is subtracted from an accumulator starting at 0.
///
/// NOTE(review): algebraically s - (s - s0) is just s0, so the result is about
/// minus the score of wa[0] times the number of entries considered. This does
/// not look like an entropy; confirm the intended formula before relying on it.
///
/// @param wa Weighted actions; only the scores (second.first) are read.
/// @return The accumulated value, 0.0 when wa is empty.
float Classifier::computeEntropy(WeightedActions & wa)
{
  float entropy = 0.0;

  const unsigned int maxEntries = 2;

  for (unsigned int k = 0; k < maxEntries && k < wa.size(); k++)
  {
    const float score = wa[k].second.first;
    const float topScore = wa[0].second.first;

    // Kept in this exact form : floating point arithmetic is not guaranteed to
    // give the same result for an algebraically simplified expression.
    entropy -= score - (score - topScore);
  }

  return entropy;
}