/// @file Trainer.hpp /// @author Franck Dary /// @version 1.0 /// @date 2018-08-03 #ifndef TRAINER__H #define TRAINER__H #include "TransitionMachine.hpp" #include "BD.hpp" #include "Config.hpp" #include "TrainInfos.hpp" /// @brief An object capable of training a TransitionMachine given a BD initialized with training examples. class Trainer { private : struct EndOfIteration : public std::exception { const char * what() const throw() { return "Iteration must end because an oracle could not find a zero-cost action."; } }; struct EndOfTraining : public std::exception { const char * what() const throw() { return "Training must end because every epoch has happened."; } }; private : /// @brief The TransitionMachine that will be trained. TransitionMachine & tm; /// @brief The BD initialized with training examples. BD & trainBD; /// @brief The configuration of the TransitionMachine while processing trainBD. Config & trainConfig; /// @brief The BD initialized with dev examples. /// /// Can be nullptr if dev is not used in this training. BD * devBD; /// @brief The configuration of the TransitionMachine while processing devBD. /// /// Can be nullptr if dev is not used in this training. Config * devConfig; /// @brief Lots of informations about the current training. TrainInfos TI; /// @brief Number of training steps done so far. int nbSteps; /// @brief Number of Actions taken so far. int nbActions; /// @brief Number of Actions needed to compute speed. int nbActionsCutoff; /// @brief Current training speed in actions per second. float currentSpeed; /// @brief The date the last time the speed has been computed. std::chrono::time_point<std::chrono::high_resolution_clock> pastTime; public : /// @brief The FeatureDescritpion of a Config. using FD = FeatureModel::FeatureDescription; private : /// @brief Compute and print scores for each Classifier on this epoch, and save the Classifier if they achieved their all time best score. void printScoresAndSave(FILE * output); /// @brief Get the scores of the classifiers on the dev dataset. void computeScoreOnDev(); /// @brief Read the input file again and shuffle it. void resetAndShuffle(); /// @brief Run the current classifier and take the next transition, no training. void doStepNoTrain(); /// @brief Run the current classifier and take the next transition, training the classifier. void doStepTrain(); /// @brief Compute and print dev scores, increase epoch counter. void prepareNextEpoch(); public : /// @brief Construct a new Trainer without a dev set. /// /// @param tm The TransitionMachine to use. /// @param bd The BD to use. /// @param config The config to use. Trainer(TransitionMachine & tm, BD & bd, Config & config); /// @brief Construct a new Trainer with a dev set. /// /// @param tm The TransitionMachine to use. /// @param bd The BD corresponding to the training dataset. /// @param config The Config corresponding to bd. /// @param devBD The BD corresponding to the dev dataset. /// @param devConfig The Config corresponding to devBD. Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig); /// @brief Train the TransitionMachine one example at a time. /// /// For each epoch all the Classifier of the TransitionMachine are fed all the /// training examples, at the end of the epoch Classifier are evaluated on /// the devBD if available, and each Classifier will be saved only if its score /// on the current epoch is its all time best.\n /// When a Classifier is saved that way, all the Dict involved are also saved. void train(); }; #endif