Something went wrong on our end
Select Git revision
Trainer.hpp
-
Franck Dary authoredFranck Dary authored
Trainer.hpp 3.03 KiB
/// @file Trainer.hpp
/// @author Franck Dary
/// @version 1.0
/// @date 2018-08-03
#ifndef TRAINER__H
#define TRAINER__H
#include <cstdio>
#include <map>
#include <string>
#include <utility>
#include "TransitionMachine.hpp"
#include "BD.hpp"
#include "Config.hpp"
/// @brief An object capable of training a TransitionMachine given a BD initialized with training examples.
///
/// The Trainer stores references and raw pointers to the objects it is given
/// and declares no destructor, so it never releases them: the
/// TransitionMachine, BDs and Configs passed to a constructor must stay alive
/// for as long as the Trainer uses them.
class Trainer
{
  private :

  /// @brief The TransitionMachine that will be trained.
  TransitionMachine & tm;
  /// @brief The BD initialized with training examples.
  BD & trainBD;
  /// @brief The configuration of the TransitionMachine while processing trainBD.
  Config & trainConfig;
  /// @brief The BD initialized with dev examples.
  ///
  /// Can be nullptr if dev is not used in this training.
  /// Defaults to nullptr so that a constructor which does not set it never
  /// leaves an indeterminate pointer behind.
  BD * devBD = nullptr;
  /// @brief The configuration of the TransitionMachine while processing devBD.
  ///
  /// Can be nullptr if dev is not used in this training.
  /// Defaults to nullptr for the same reason as devBD.
  Config * devConfig = nullptr;
  /// @brief For each classifier, a pair (number of examples seen, number of examples successfully classified).
  std::map< std::string, std::pair<int, int> > trainCounter;
  /// @brief For each classifier, the train score for the current iteration.
  std::map< std::string, float > scores;
  /// @brief For each classifier, the best score seen on dev.
  std::map< std::string, float > bestScores;
  /// @brief Whether or not each Classifier topology has been printed.
  std::map< std::string, bool > topologyPrinted;
  /// @brief Current iteration.
  ///
  /// Defaults to 0 so the value is never indeterminate before training starts.
  int curIter = 0;

  public :

  /// @brief The FeatureDescription of a Config.
  using FD = FeatureModel::FeatureDescription;

  private :

  /// @brief Compute and print the scores of each Classifier for this epoch, and save each Classifier that achieved its all-time best score.
  ///
  /// @param output The C stream the scores are printed to.
  void printScoresAndSave(FILE * output);
  /// @brief Get the scores of the classifiers on the dev dataset.
  ///
  /// @return Map from each Classifier name to its score.
  std::map<std::string, float> getScoreOnDev();

  public :

  /// @brief Construct a new Trainer without a dev set.
  ///
  /// @param tm The TransitionMachine to use.
  /// @param bd The BD to use.
  /// @param config The config to use.
  Trainer(TransitionMachine & tm, BD & bd, Config & config);
  /// @brief Construct a new Trainer with a dev set.
  ///
  /// @param tm The TransitionMachine to use.
  /// @param bd The BD corresponding to the training dataset.
  /// @param config The Config corresponding to bd.
  /// @param devBD The BD corresponding to the dev dataset.
  /// @param devConfig The Config corresponding to devBD.
  Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig);
  /// @brief Train the TransitionMachine one example at a time.
  ///
  /// During each epoch, every Classifier of the TransitionMachine is fed all the
  /// training examples. At the end of the epoch each Classifier is evaluated on
  /// devBD if available, and is saved only if its score on the current epoch is
  /// its all-time best.\n
  /// When a Classifier is saved that way, all the Dict involved are also saved.
  void train();
};
#endif