  • Trainer.cpp 3.26 KiB
    #include "Trainer.hpp"
    
    Trainer::Trainer(TapeMachine & tm, MCD & mcd, Config & config)
    : tm(tm), mcd(mcd), config(config)
    {
    }
    
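    // Unbatched training: for nbIter passes, walk the machine along the oracle
    // path; at each state the classifier weights its actions given the gold
    // action, and its top-weighted choice is compared to it to track accuracy.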
    void Trainer::trainUnbatched()
    {
      int nbIter = 5;
    
      fprintf(stderr, "Training of \'%s\' :\n", tm.name.c_str());
    
      for (int i = 0; i < nbIter; i++)
      {
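        // per-classifier counters: first = examples seen, second = correct predictions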
        std::map< std::string, std::pair<int, int> > nbExamples;
    
        while (!config.isFinal())
        {
          TapeMachine::State * currentState = tm.getCurrentState();
          Classifier * classifier = currentState->classifier;
    
          //config.printForDebug(stderr);
    
          //fprintf(stderr, "State : \'%s\'\n", currentState->name.c_str());
    
          std::string neededActionName = classifier->getOracleAction(config);
          auto weightedActions = classifier->weightActions(config, neededActionName);
          //printWeightedActions(stderr, weightedActions);
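          // the first weighted action is treated as the classifier's prediction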
          std::string & predictedAction = weightedActions[0].second;
    
          nbExamples[classifier->name].first++;
          if(predictedAction == neededActionName)
            nbExamples[classifier->name].second++;
    
          //fprintf(stderr, "Action : \'%s\'\n", neededActionName.c_str());
    
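          // always execute the oracle action so training stays on the gold path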
          TapeMachine::Transition * transition = tm.getTransition(neededActionName);
          tm.takeTransition(transition);
          config.moveHead(transition->headMvt);
        }
    
        fprintf(stderr, "Iteration %d/%d :\n", i+1, nbIter);
        for(auto & it : nbExamples)
          fprintf(stderr, "\t%s %.2f%% accuracy\n", it.first.c_str(), 100.0*it.second.second / it.second.first);
    
        config.reset();
      }
    }
    
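    // Batched training: a single pass along the oracle path collects one
    // (gold action index, feature description) example per visited state, then
    // every classifier is trained for nbIter epochs on mini-batches of them.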
    void Trainer::trainBatched()
    {
      using FD = FeatureModel::FeatureDescription;
      using Example = std::pair<int, FD>;
    
      std::map<Classifier*, std::vector<Example> > examples;
    
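      // collection pass: record one training example for each state encountered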
      while (!config.isFinal())
      {
        TapeMachine::State * currentState = tm.getCurrentState();
        Classifier * classifier = currentState->classifier;
        classifier->initClassifier(config);
    
        int neededActionIndex = classifier->getOracleActionIndex(config);
        std::string neededActionName = classifier->getActionName(neededActionIndex);
    
        examples[classifier].emplace_back(Example(neededActionIndex, classifier->getFeatureDescription(config)));
    
        TapeMachine::Transition * transition = tm.getTransition(neededActionName);
        tm.takeTransition(transition);
        config.moveHead(transition->headMvt);
      }
    
      int nbIter = 5;
      int batchSize = 256;
    
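      // training phase: nbIter epochs of mini-batch updates per classifier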
      for (int i = 0; i < nbIter; i++)
      {
        std::map< std::string, std::pair<int, int> > nbExamples;
    
        for(auto & it : examples)
        {
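          // round up so a final, smaller batch is still processed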
          int nbBatches = (it.second.size() / batchSize) + (it.second.size() % batchSize ? 1 : 0);
    
          for(int numBatch = 0; numBatch < nbBatches; numBatch++)
          {
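            // the last batch may contain fewer than batchSize examples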
            int currentBatchSize = std::min<int>(batchSize, it.second.size() - (numBatch * batchSize));
    
            auto batchStart = it.second.begin() + numBatch * batchSize;
            auto batchEnd = batchStart + currentBatchSize;
    
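            // trainOnBatch is expected to return the number of examples of the
            // batch that the classifier predicted correctly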
            int nbCorrect = it.first->trainOnBatch(batchStart, batchEnd);
    
            nbExamples[it.first->name].first += currentBatchSize;
            nbExamples[it.first->name].second += nbCorrect;
          }
        }
    
        fprintf(stderr, "Iteration %d/%d :\n", i+1, nbIter);
        for(auto & it : nbExamples)
          fprintf(stderr, "\t%s %.2f%% accuracy\n", it.first.c_str(), 100.0*it.second.second / it.second.first);
    
      }
    }
    
    void Trainer::train()
    {
    //  trainUnbatched();
      trainBatched();
    }
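
A minimal usage sketch, not part of the file above: runTraining is a hypothetical helper, and the construction of the TapeMachine, MCD and Config is left to the caller because it is project-specific; only the Trainer constructor and train() defined in Trainer.cpp are assumed here.

    #include "Trainer.hpp"

    // Hypothetical driver: tm, mcd and config are assumed to be built elsewhere.
    void runTraining(TapeMachine & tm, MCD & mcd, Config & config)
    {
      Trainer trainer(tm, mcd, config); // constructor defined in Trainer.cpp
      trainer.train();                  // currently dispatches to trainBatched()
    }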