Skip to content
Snippets Groups Projects
Select Git revision
  • master default protected
  • fullUD
  • movementInAction
3 results

Trainer.hpp

Blame
  • Trainer.hpp 3.03 KiB
    /// @file Trainer.hpp
    /// @author Franck Dary
    /// @version 1.0
    /// @date 2018-08-03
    
    #ifndef TRAINER__H
    #define TRAINER__H
    
    #include "TransitionMachine.hpp"
    #include "BD.hpp"
    #include "Config.hpp"
    
    /// @brief An object capable of training a TransitionMachine given a BD initialized with training examples.
    ///
    /// The TransitionMachine, the training BD and the training Config are held by
    /// reference, the optional dev BD and dev Config by raw pointer.
    /// As a consequence of the reference members, the implicitly-declared
    /// assignment operators of this class are deleted.\n
    /// NOTE(review): the Trainer presumably does not own any of these objects and
    /// they must outlive it; confirm against the constructors' definitions.
    class Trainer
    {
      private :

      /// @brief The TransitionMachine that will be trained.
      TransitionMachine & tm;
      /// @brief The BD initialized with training examples.
      BD & trainBD;
      /// @brief The configuration of the TransitionMachine while processing trainBD.
      Config & trainConfig;

      /// @brief The BD initialized with dev examples.
      ///
      /// Can be nullptr if dev is not used in this training.
      /// Defaults to nullptr so that a constructor that does not set it
      /// never leaves an indeterminate pointer behind.
      BD * devBD = nullptr;
      /// @brief The configuration of the TransitionMachine while processing devBD.
      ///
      /// Can be nullptr if dev is not used in this training.
      /// Defaults to nullptr for the same reason as devBD.
      Config * devConfig = nullptr;

      /// @brief For each classifier, a pair of number examples seen / number examples successfully classified
      ///
      /// The key is the classifier name; first is the number of examples seen,
      /// second is the number of examples successfully classified.
      std::map< std::string, std::pair<int, int> > trainCounter;

      /// @brief For each classifier, the train score for the current iteration.
      std::map< std::string, float > scores;

      /// @brief For each classifier, the best score seen on dev.
      std::map< std::string, float > bestScores;

      /// @brief Whether or not each Classifier topology has been printed.
      std::map< std::string, bool > topologyPrinted;

      /// @brief Current iteration.
      ///
      /// Defaults to 0 so that it is never read while indeterminate.
      int curIter = 0;

      public :

      /// @brief The FeatureDescription of a Config.
      using FD = FeatureModel::FeatureDescription;

      private :

      /// @brief Compute and print scores for each Classifier on this epoch, and save the Classifier if they achieved their all time best score.
      ///
      /// @param output The C stream given to the printing code, presumably the destination of the printed scores.
      void printScoresAndSave(FILE * output);

      /// @brief Get the scores of the classifiers on the dev dataset.
      ///
      /// @return Map from each Classifier name to their score.
      std::map<std::string, float> getScoreOnDev();

      public :

      /// @brief Construct a new Trainer without a dev set.
      ///
      /// @param tm The TransitionMachine to use.
      /// @param bd The BD to use.
      /// @param config The config to use.
      Trainer(TransitionMachine & tm, BD & bd, Config & config);
      /// @brief Construct a new Trainer with a dev set.
      ///
      /// @param tm The TransitionMachine to use.
      /// @param bd The BD corresponding to the training dataset.
      /// @param config The Config corresponding to bd.
      /// @param devBD The BD corresponding to the dev dataset.
      /// @param devConfig The Config corresponding to devBD.
      Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig);

      /// @brief Train the TransitionMachine one example at a time.
      ///
      /// For each epoch all the Classifier of the TransitionMachine are fed all the 
      /// training examples, at the end of the epoch Classifier are evaluated on 
      /// the devBD if available, and each Classifier will be saved only if its score
      /// on the current epoch is its all time best.\n
      /// When a Classifier is saved that way, all the Dict involved are also saved.
      void train();
    };
    
    #endif