Something went wrong on our end
Select Git revision
Trainer.hpp
-
Franck Dary authored
Trainer.hpp 3.89 KiB
/// @file Trainer.hpp
/// @author Franck Dary
/// @version 1.0
/// @date 2018-08-03
#ifndef TRAINER__H
#define TRAINER__H
#include <chrono>
#include <cstdio>
#include <exception>
#include <map>
#include <string>
#include "TransitionMachine.hpp"
#include "BD.hpp"
#include "Config.hpp"
#include "TrainInfos.hpp"
/// @brief An object capable of training a TransitionMachine given a BD initialized with training examples.
class Trainer
{
  private :

  /// @brief Exception signalling that the current iteration must end.
  ///
  /// Its message states that an oracle could not find a zero-cost action.
  struct EndOfIteration : public std::exception
  {
    /// @brief Describe why the iteration must end.
    ///
    /// @return A string literal, valid for the whole program lifetime.
    const char * what() const noexcept override
    {
      return "Iteration must end because an oracle could not find a zero-cost action.";
    }
  };

  /// @brief Exception signalling that the training must end.
  ///
  /// Its message states that every epoch has happened.
  struct EndOfTraining : public std::exception
  {
    /// @brief Describe why the training must end.
    ///
    /// @return A string literal, valid for the whole program lifetime.
    const char * what() const noexcept override
    {
      return "Training must end because every epoch has happened.";
    }
  };

  public :

  /// @brief The FeatureDescription of a Config.
  using FD = FeatureModel::FeatureDescription;

  private :

  /// @brief The TransitionMachine that will be trained.
  TransitionMachine & tm;
  /// @brief The BD initialized with training examples.
  BD & trainBD;
  /// @brief The configuration of the TransitionMachine while processing trainBD.
  Config & trainConfig;
  /// @brief The BD initialized with dev examples.
  ///
  /// Can be nullptr if dev is not used in this training.
  BD * devBD = nullptr;
  /// @brief The configuration of the TransitionMachine while processing devBD.
  ///
  /// Can be nullptr if dev is not used in this training.
  Config * devConfig = nullptr;
  /// @brief Information about the current training.
  TrainInfos TI;
  /// @brief Number of training steps done so far.
  int nbSteps = 0;
  /// @brief Number of Actions taken so far.
  int nbActions = 0;
  /// @brief Number of Actions needed to compute speed.
  ///
  /// NOTE(review): the in-class value is only a safe fallback; the real
  /// value is presumably set by the constructors — confirm in Trainer.cpp.
  int nbActionsCutoff = 0;
  /// @brief Current training speed in actions per second.
  float currentSpeed = 0.0f;
  /// @brief The date the last time the speed has been computed.
  std::chrono::time_point<std::chrono::high_resolution_clock> pastTime;
  /// @brief For each classifier, a FeatureDescription it needs to remember for a future update.
  std::map<std::string,FD> pendingFD;

  private :

  /// @brief Compute and print scores for each Classifier on this epoch, and save the Classifier if they achieved their all time best score.
  ///
  /// @param output The stream the scores are printed to.
  void printScoresAndSave(FILE * output);
  /// @brief Get the scores of the classifiers on the dev dataset.
  void computeScoreOnDev();
  /// @brief Read the input file again and shuffle it.
  void resetAndShuffle();
  /// @brief Run the current classifier and take the next transition, no training.
  void doStepNoTrain();
  /// @brief Run the current classifier and take the next transition, training the classifier.
  void doStepTrain();
  /// @brief Compute and print dev scores, increase epoch counter.
  void prepareNextEpoch();
  /// @brief Set the debug variable ProgramParameters::debug.
  void setDebugValue();

  public :

  /// @brief Construct a new Trainer without a dev set.
  ///
  /// @param tm The TransitionMachine to use.
  /// @param bd The BD to use.
  /// @param config The config to use.
  Trainer(TransitionMachine & tm, BD & bd, Config & config);
  /// @brief Construct a new Trainer with a dev set.
  ///
  /// @param tm The TransitionMachine to use.
  /// @param bd The BD corresponding to the training dataset.
  /// @param config The Config corresponding to bd.
  /// @param devBD The BD corresponding to the dev dataset.
  /// @param devConfig The Config corresponding to devBD.
  Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig);
  /// @brief Train the TransitionMachine one example at a time.
  ///
  /// For each epoch all the Classifier of the TransitionMachine are fed all the
  /// training examples, at the end of the epoch Classifier are evaluated on
  /// the devBD if available, and each Classifier will be saved only if its score
  /// on the current epoch is its all time best.\n
  /// When a Classifier is saved that way, all the Dict involved are also saved.
  void train();
};
#endif