// Classifier.cpp
// Author: Franck Dary
#include "Classifier.hpp"
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <string>
#include "File.hpp"
#include "util.hpp"
/// @brief Build a classifier from its descriptor file.
///
/// The file must contain, in this order, the five fields
///   Name : <token>
///   Type : <token>
///   Feature Model : <token>
///   Action Set : <token>
///   Oracle : <token>
/// Each value is read with fscanf's %s, so it is a single whitespace-free token.
/// Any field that does not match makes the program print an error and exit(1).
///
/// @param filename Path of the classifier descriptor file.
Classifier::Classifier(const std::string & filename)
{
  // Report a malformed descriptor and terminate. errInfo identifies the call
  // site (ERRINFO is presumably defined in util.hpp).
  auto badFormatAndAbort = [&filename](const char * errInfo)
  {
    fprintf(stderr, "ERROR (%s) : file %s bad format. Aborting.\n", errInfo, filename.c_str());
    exit(1);
  };

  File file(filename, "r");
  FILE * fd = file.getDescriptor();

  // Every %s conversion below is bounded to sizeof(buffer) - 1 characters, so an
  // overlong token cannot overflow this stack buffer. Any excess characters are
  // left in the stream, where they break the next field's literal match and
  // trigger badFormatAndAbort instead of undefined behaviour.
  char buffer[1024];
  static_assert(sizeof(buffer) == 1024, "keep the %1023s field widths in sync with buffer");

  if(fscanf(fd, "Name : %1023s\n", buffer) != 1)
    badFormatAndAbort(ERRINFO);
  name = buffer;

  if(fscanf(fd, "Type : %1023s\n", buffer) != 1)
    badFormatAndAbort(ERRINFO);
  type = str2type(buffer);

  if(fscanf(fd, "Feature Model : %1023s\n", buffer) != 1)
    badFormatAndAbort(ERRINFO);
  fm.reset(new FeatureModel(buffer));

  if(fscanf(fd, "Action Set : %1023s\n", buffer) != 1)
    badFormatAndAbort(ERRINFO);
  as.reset(new ActionSet(buffer));

  if(fscanf(fd, "Oracle : %1023s\n", buffer) != 1)
    badFormatAndAbort(ERRINFO);
  oracle = Oracle::getOracle(buffer);
}
/// @brief Convert a textual classifier type into its Type value.
///
/// Accepted spellings are "Prediction", "Information" and "Forced".
/// Any other string is reported on stderr and the program exits with status 1.
///
/// @param s Type name as read from the classifier file.
/// @return The matching Type.
Classifier::Type Classifier::str2type(const std::string & s)
{
  // Spelling -> enum value table, scanned in order.
  struct NamedType
  {
    const char * label;
    Type value;
  };

  static const NamedType knownTypes[] =
  {
    {"Prediction", Type::Prediction},
    {"Information", Type::Information},
    {"Forced", Type::Forced},
  };

  for (const NamedType & candidate : knownTypes)
  {
    if (s == candidate.label)
      return candidate.value;
  }

  fprintf(stderr, "ERROR (%s) : invalid type \'%s\'. Aborting.\n", ERRINFO, s.c_str());
  exit(1);

  // Unreachable (exit does not return); keeps every path returning a value.
  return Type::Prediction;
}
/// @brief Score every action of the action set for the given configuration.
///
/// The MLP is built lazily on first use. The gold action's index is forwarded
/// to MLP::predict; how predict uses it is defined in MLP.
///
/// @param config     Current parser configuration, used to extract features.
/// @param goldAction Name of the gold action, translated to its action-set index.
/// @return (score, action name) pairs sorted by decreasing score. Actions with
///         equal scores keep their action-set order.
Classifier::WeightedActions Classifier::weightActions(Config & config, const std::string & goldAction)
{
  initClassifier(config);

  int actionIndex = as->getActionIndex(goldAction);
  auto fd = fm->getFeatureDescription(config);
  auto scores = mlp->predict(fd, actionIndex);

  // scores[i] is paired with the i-th action of the action set.
  WeightedActions result;
  for (std::size_t i = 0; i < scores.size(); i++)
    result.emplace_back(scores[i], as->actions[i].name);

  // stable_sort, not sort: with std::sort the relative order of tied scores
  // (and therefore which action ranks first) is unspecified.
  std::stable_sort(result.begin(), result.end(),
    [](const WeightedActions::value_type & a, const WeightedActions::value_type & b)
    {
      return a.first > b.first;
    });

  return result;
}
/// @brief Build the MLP the first time it is needed; later calls do nothing.
///
/// The network has two layers:
///   inputs -> 200 hidden units (RELU) -> one output per action (LINEAR),
/// with 0.0 passed as each layer's third parameter.
/// The input size is the total length of all feature vectors the feature model
/// produces for config.
///
/// @param config Configuration used to probe the feature vector sizes.
void Classifier::initClassifier(Config & config)
{
  if(mlp)
    return;

  const int nbHidden = 200;
  const int nbOutputs = static_cast<int>(as->actions.size());

  int nbInputs = 0;
  auto fd = fm->getFeatureDescription(config);
  // Iterate by reference: the original copied each feature value just to read its size.
  for (const auto & feat : fd.values)
    nbInputs += static_cast<int>(feat.vec->size());

  mlp.reset(new MLP({{nbInputs, nbHidden, 0.0, MLP::Activation::RELU},
                     {nbHidden, nbOutputs, 0.0, MLP::Activation::LINEAR}}));
}
/// @brief Extract the feature description of config using this classifier's
///        feature model.
/// @param config Configuration to describe.
/// @return Whatever FeatureModel::getFeatureDescription returns for config.
FeatureModel::FeatureDescription Classifier::getFeatureDescription(Config & config)
{
  auto & featureModel = *fm;
  return featureModel.getFeatureDescription(config);
}
/// @brief Ask the oracle loaded from the classifier file for its action on config.
/// @param config Current configuration.
/// @return The action name returned by the oracle.
std::string Classifier::getOracleAction(Config & config)
{
  auto & goldOracle = *oracle;
  return goldOracle.getAction(config);
}
/// @brief Index, within the action set, of the action the oracle chooses for config.
/// @param config Current configuration.
/// @return Result of ActionSet::getActionIndex on the oracle's action name.
int Classifier::getOracleActionIndex(Config & config)
{
  const std::string goldAction = oracle->getAction(config);
  return as->getActionIndex(goldAction);
}
/// @brief Forward a batch of (action index, feature description) examples to
///        the MLP's own trainOnBatch.
///
/// NOTE(review): mlp is only created by initClassifier, so this presumably
/// must not be called before the network exists. Verify against the callers.
///
/// @param start Iterator to the first example of the batch.
/// @param end   End iterator of the batch (presumably one past the last example).
/// @return The value returned by MLP::trainOnBatch.
int Classifier::trainOnBatch(std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & start, std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & end)
{
  auto & network = *mlp;
  return network.trainOnBatch(start, end);
}
/// @brief Look up an action's name from its index in the action set.
/// @param actionIndex Index in the action set.
/// @return The name returned by ActionSet::getActionName.
std::string Classifier::getActionName(int actionIndex)
{
  auto & actionSet = *as;
  return actionSet.getActionName(actionIndex);
}
/// @brief Print weighted actions as a table framed by separator lines.
///
/// Output: a line of 80 '-' characters, then one "<score>\t<name>" line per
/// entry (score printed with two decimals), then another 80-dash line.
///
/// @param output Stream to write to.
/// @param wa     (score, action name) pairs, printed in their current order.
void Classifier::printWeightedActions(FILE * output, WeightedActions & wa)
{
  const int nbCols = 80;
  const char symbol = '-';
  // Build the separator once instead of issuing 80 separate fprintf calls.
  const std::string separator = std::string(nbCols, symbol) + "\n";

  fputs(separator.c_str(), output);
  // Iterate by reference: the original copied every pair, including its string.
  for (const auto & action : wa)
    fprintf(output, "%.2f\t%s\n", action.first, action.second.c_str());
  fputs(separator.c_str(), output);
}