Skip to content
Snippets Groups Projects
Select Git revision
  • b8e657ce3f87b8316e8bf95a9b22226749561a67
  • master default protected
  • fullUD
  • movementInAction
4 results

Classifier.cpp

Blame
  • hartbook's avatar
    Franck Dary authored
    b8e657ce
    History
    Classifier.cpp 3.36 KiB
    #include "Classifier.hpp"
    #include "File.hpp"
    #include "util.hpp"
    
    /// Builds a classifier from the textual description stored in `filename`.
    ///
    /// The file must contain, in this order, the fields `Name`, `Type`,
    /// `Feature Model`, `Action Set` and `Oracle`, each written as
    /// `Label : value` where `value` is a single whitespace-free token.
    /// If a field cannot be read, an error is printed on stderr and the
    /// program exits with status 1.
    ///
    /// @param filename Path of the classifier description file.
    Classifier::Classifier(const std::string & filename)
    {
      auto badFormatAndAbort = [&filename](const char * errInfo)
      {
        fprintf(stderr, "ERROR (%s) : file %s bad format. Aborting.\n", errInfo, filename.c_str());
    
        exit(1);
      };
    
      // NOTE(review): File presumably owns the descriptor and closes it when
      // `file` goes out of scope -- confirm in File.hpp.
      File file(filename, "r");
      FILE * fd = file.getDescriptor();
    
      char buffer[1024];
    
      // Each %s conversion is capped at 1023 characters so that, together with
      // the terminating NUL, a token can never overflow the 1024-byte buffer.
      if(fscanf(fd, "Name : %1023s\n", buffer) != 1)
        badFormatAndAbort(ERRINFO);
    
      name = buffer;
    
      if(fscanf(fd, "Type : %1023s\n", buffer) != 1)
        badFormatAndAbort(ERRINFO);
    
      type = str2type(buffer);
    
      if(fscanf(fd, "Feature Model : %1023s\n", buffer) != 1)
        badFormatAndAbort(ERRINFO);
    
      fm.reset(new FeatureModel(buffer));
    
      if(fscanf(fd, "Action Set : %1023s\n", buffer) != 1)
        badFormatAndAbort(ERRINFO);
    
      as.reset(new ActionSet(buffer));
    
      if(fscanf(fd, "Oracle : %1023s\n", buffer) != 1)
        badFormatAndAbort(ERRINFO);
    
      oracle = Oracle::getOracle(buffer);
    }
    
    /// Converts the textual name of a classifier type into its enum value.
    ///
    /// Recognised names are "Prediction", "Information" and "Forced" (exact,
    /// case-sensitive match). Any other string prints an error on stderr and
    /// terminates the program with status 1.
    ///
    /// @param s Name of the type, as read from the classifier file.
    /// @return The matching Type value.
    Classifier::Type Classifier::str2type(const std::string & s)
    {
      struct NamedType
      {
        const char * label;
        Type value;
      };
    
      static const NamedType knownTypes[] =
      {
        {"Prediction", Type::Prediction},
        {"Information", Type::Information},
        {"Forced", Type::Forced}
      };
    
      for (const NamedType & candidate : knownTypes)
      {
        if (s == candidate.label)
          return candidate.value;
      }
    
      fprintf(stderr, "ERROR (%s) : invalid type \'%s\'. Aborting.\n", ERRINFO, s.c_str());
      exit(1);
    
      // Never reached: only here so that every path returns a value.
      return Type::Prediction;
    }
    
    /// Scores every action for the given configuration and ranks them.
    ///
    /// Makes sure the network exists (initClassifier), asks it for one score
    /// per action, pairs each score with the name of the action at the same
    /// index in the action set, and sorts the pairs by decreasing score.
    ///
    /// @param config     Configuration the features are extracted from.
    /// @param goldAction Name of the reference action; its index is handed to
    ///                   the network's predict call.
    /// @return (score, action name) pairs, highest score first.
    Classifier::WeightedActions Classifier::weightActions(Config & config, const std::string & goldAction)
    {
      initClassifier(config);
    
      int goldIndex = as->getActionIndex(goldAction);
    
      auto description = fm->getFeatureDescription(config);
      auto scores = mlp->predict(description, goldIndex);
    
      WeightedActions ranked;
    
      unsigned int index = 0;
      while (index < scores.size())
      {
        ranked.emplace_back(scores[index], as->actions[index].name);
        index++;
      }
    
      auto byDecreasingScore = [](const std::pair<float, std::string> & lhs, const std::pair<float, std::string> & rhs)
      {
        return lhs.first > rhs.first;
      };
    
      std::sort(ranked.begin(), ranked.end(), byDecreasingScore);
    
      return ranked;
    }
    
    /// Lazily creates the underlying MLP; does nothing if it already exists.
    ///
    /// The input size is the sum of the sizes of the vectors of every feature
    /// value extracted from `config`; the output size is the number of actions
    /// in the action set. The network has one hidden layer of 200 units (RELU)
    /// followed by a LINEAR output layer.
    ///
    /// @param config Configuration used to measure the size of the feature
    ///               description.
    void Classifier::initClassifier(Config & config)
    {
      if(mlp.get())
        return;
    
      int nbInputs = 0;
      int nbHidden = 200;
      int nbOutputs = as->actions.size();
    
      auto fd = fm->getFeatureDescription(config);
    
      // Iterate by const reference: only the size is read, so copying each
      // feature value would be wasted work.
      for (const auto & feat : fd.values)
        nbInputs += feat.vec->size();
    
      // NOTE(review): the third field of each layer (0.0) is presumably a
      // dropout rate -- confirm against the MLP layer definition.
      mlp.reset(new MLP({{nbInputs, nbHidden, 0.0, MLP::Activation::RELU},
                         {nbHidden, nbOutputs, 0.0, MLP::Activation::LINEAR}}));
    }
    
    /// Returns the feature description of `config`, as computed by this
    /// classifier's feature model.
    ///
    /// @param config Configuration handed to the feature model.
    /// @return Whatever fm->getFeatureDescription(config) returns.
    FeatureModel::FeatureDescription Classifier::getFeatureDescription(Config & config)
    {
      return fm->getFeatureDescription(config);
    }
    
    /// Returns the action chosen by this classifier's oracle for `config`.
    ///
    /// @param config Configuration handed to the oracle.
    /// @return The result of oracle->getAction(config), as a string.
    std::string Classifier::getOracleAction(Config & config)
    {
      return oracle->getAction(config);
    }
    
    /// Returns the index, in the action set, of the action the oracle picks
    /// for `config`.
    ///
    /// @param config Configuration handed to the oracle.
    /// @return Index of the oracle's action according to the action set.
    int Classifier::getOracleActionIndex(Config & config)
    {
      const std::string oracleAction = oracle->getAction(config);
    
      return as->getActionIndex(oracleAction);
    }
    
    /// Forwards a batch of training examples to the underlying MLP.
    ///
    /// NOTE(review): in this file `mlp` is only created by initClassifier and
    /// is dereferenced here without a check -- callers presumably must have
    /// triggered initClassifier first; confirm.
    ///
    /// @param start Iterator to the first (int, feature description) example
    ///              of the batch; the int is presumably the gold action index
    ///              -- TODO confirm.
    /// @param end   Iterator marking the end of the batch.
    /// @return The value returned by mlp->trainOnBatch(start, end).
    int Classifier::trainOnBatch(std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & start, std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & end)
    {
      return mlp->trainOnBatch(start, end);
    }
    
    /// Returns the name of the action stored at `actionIndex` in the action set.
    ///
    /// @param actionIndex Index handed to as->getActionName.
    /// @return The action name reported by the action set.
    std::string Classifier::getActionName(int actionIndex)
    {
      return as->getActionName(actionIndex);
    }
    
    /// Prints a list of weighted actions framed by two horizontal rules.
    ///
    /// Output layout: a line of 80 '-' characters, then one line per entry
    /// formatted as "<score with 2 decimals>\t<action name>", then a second
    /// line of 80 '-' characters.
    ///
    /// @param output Stream the table is written to.
    /// @param wa     Weighted actions to print, in the order given.
    void Classifier::printWeightedActions(FILE * output, WeightedActions & wa)
    {
      const int nbCols = 80;
      const char symbol = '-';
    
      // Writes one full-width separator line followed by a newline.
      auto printRule = [&]()
      {
        for(int i = 0; i < nbCols; i++)
          fputc(symbol, output);
    
        fputc('\n', output);
      };
    
      printRule();
    
      // Iterate by const reference to avoid copying each (score, name) pair.
      for (const auto & entry : wa)
        fprintf(output, "%.2f\t%s\n", entry.first, entry.second.c_str());
    
      printRule();
    }