Select Git revision
index.rst
Trainer.cpp 3.26 KiB
#include "Trainer.hpp"
// Builds a Trainer from the tape machine, the MCD and the configuration it will work on.
// Each argument initializes the member of the same name; the constructor does nothing else.
// NOTE(review): whether these members are references or copies is decided by their
// declarations in Trainer.hpp, which is not visible from this file - confirm there.
Trainer::Trainer(TapeMachine & tm, MCD & mcd, Config & config)
: tm(tm), mcd(mcd), config(config)
{
}
void Trainer::trainUnbatched()
{
int nbIter = 5;
fprintf(stderr, "Training of \'%s\' :\n", tm.name.c_str());
for (int i = 0; i < nbIter; i++)
{
std::map< std::string, std::pair<int, int> > nbExamples;
while (!config.isFinal())
{
TapeMachine::State * currentState = tm.getCurrentState();
Classifier * classifier = currentState->classifier;
//config.printForDebug(stderr);
//fprintf(stderr, "State : \'%s\'\n", currentState->name.c_str());
std::string neededActionName = classifier->getOracleAction(config);
auto weightedActions = classifier->weightActions(config, neededActionName);
//printWeightedActions(stderr, weightedActions);
std::string & predictedAction = weightedActions[0].second;
nbExamples[classifier->name].first++;
if(predictedAction == neededActionName)
nbExamples[classifier->name].second++;
//fprintf(stderr, "Action : \'%s\'\n", neededActionName.c_str());
TapeMachine::Transition * transition = tm.getTransition(neededActionName);
tm.takeTransition(transition);
config.moveHead(transition->headMvt);
}
fprintf(stderr, "Iteration %d/%d :\n", i+1, nbIter);
for(auto & it : nbExamples)
fprintf(stderr, "\t%s %.2f%% accuracy\n", it.first.c_str(), 100.0*it.second.second / it.second.first);
config.reset();
}
}
void Trainer::trainBatched()
{
using FD = FeatureModel::FeatureDescription;
using Example = std::pair<int, FD>;
std::map<Classifier*, std::vector<Example> > examples;
while (!config.isFinal())
{
TapeMachine::State * currentState = tm.getCurrentState();
Classifier * classifier = currentState->classifier;
classifier->initClassifier(config);
int neededActionIndex = classifier->getOracleActionIndex(config);
std::string neededActionName = classifier->getActionName(neededActionIndex);
examples[classifier].emplace_back(Example(neededActionIndex, classifier->getFeatureDescription(config)));
TapeMachine::Transition * transition = tm.getTransition(neededActionName);
tm.takeTransition(transition);
config.moveHead(transition->headMvt);
}
int nbIter = 5;
int batchSize = 256;
for (int i = 0; i < nbIter; i++)
{
std::map< std::string, std::pair<int, int> > nbExamples;
for(auto & it : examples)
{
int nbBatches = (it.second.size() / batchSize) + (it.second.size() % batchSize ? 1 : 0);
for(int numBatch = 0; numBatch < nbBatches; numBatch++)
{
int currentBatchSize = std::min<int>(batchSize, it.second.size() - (numBatch * batchSize));
auto batchStart = it.second.begin() + numBatch * batchSize;
auto batchEnd = batchStart + currentBatchSize;
int nbCorrect = it.first->trainOnBatch(batchStart, batchEnd);
nbExamples[it.first->name].first += currentBatchSize;
nbExamples[it.first->name].second += nbCorrect;
}
}
fprintf(stderr, "Iteration %d/%d :\n", i+1, nbIter);
for(auto & it : nbExamples)
fprintf(stderr, "\t%s %.2f%% accuracy\n", it.first.c_str(), 100.0*it.second.second / it.second.first);
}
}
void Trainer::train()
{
// trainUnbatched();
trainBatched();
}