/// @file Trainer.hpp
/// @author Franck Dary
/// @version 1.0
/// @date 2018-08-03

#ifndef TRAINER__H
#define TRAINER__H

#include <cstdio>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "TapeMachine.hpp"
#include "BD.hpp"
#include "Config.hpp"

/// @brief An object capable of training a TapeMachine given a BD initialized with training examples.
class Trainer
{
  public :

  /// @brief The absolute path in which this experience (training) is taking place
  std::string expPath;

  private :

  /// @brief The TapeMachine that will be trained.
  TapeMachine & tm;
  /// @brief The BD initialized with training examples.
  BD & trainBD;
  /// @brief The configuration of the TapeMachine while processing trainBD.
  Config & trainConfig;

  /// @brief The BD initialized with dev examples.
  /// Can be nullptr if dev is not used in this training.
  BD * devBD;
  /// @brief The configuration of the TapeMachine while processing devBD.
  // Can be nullptr if dev is not used in this training.
  Config * devConfig;

  public :

  using FD = FeatureModel::FeatureDescription;
  using Example = std::pair<int, FD>;
  using ExamplesIter = std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator;

  private :

  /// @brief Train the TapeMachine using batches of examples.
  /// For each epoch all the Classifier of the TapeMachine are fed all the 
  /// training examples, at the end of the epoch Classifier are evaluated on 
  /// the devBD if available, and each Classifier will be saved only if its score
  /// on the current epoch is its all time best.\n
  /// When a Classifier is saved that way, all the Dict involved are also saved.
  /// @param nbIter The number of epochs.
  /// @param batchSize The size of each batch (in number of examples).
  /// @param mustShuffle Will the examples be shuffled after every epoch ?
  void trainBatched(int nbIter, int batchSize, bool mustShuffle);
  /// @brief Extract training examples for all Classifier
  ///
  /// @param examples The map that will be filled by this function.
  /// @param config The configuration from which the examples will be extracted.
  void getExamplesByClassifier(std::map<Classifier*, MLP::Examples> & examples, Config & config);

  /// @brief Make each Classifier go over every examples.
  /// Depending on getScoreOnBatch, it can update the parameters or not.
  /// @param examples Map each trainable Classifier with a set of examples.
  /// @param batchSize The batch size to use.
  /// @param nbExamples Map each trainable Classifier to a count of how many examples it has seen during this epoch and a count of how many of this examples it has correctly classified. This map is filled by this function.
  /// @param getScoreOnBatch The MLP function that must be called to get the score of a classifier on a certain batch.
  void processAllExamples(
    std::map<Classifier*, MLP::Examples> & examples,
    int batchSize, std::map< std::string, std::pair<int, int> > & nbExamples,
    std::function<int(Classifier *, MLP::Examples &, int, int)> getScoreOnBatch);

  /// @brief Print the score obtained by all Classifier on this epoch.
  ///
  /// @param output Where to print the output.
  /// @param nbExamplesTrain Map each trainable Classifier to a count of how many train examples it has seen during this epoch and a count of how many of this examples it has correctly classified.
  /// @param nbExamplesDev Map each trainable Classifier to a count of how many dev examples it has seen during this epoch and a count of how many of this examples it has correctly classified.
  /// @param trainScores The scores obtained by each Classifier on the train set.
  /// @param devScores The scores obtained by each Classifier on the train set.
  /// @param bestIter Map each classifier to its best epoch. It is updated by this function.
  /// @param nbIter The total number of epoch of the training.
  /// @param curIter The current epoch of the training.
  void printIterationScores(FILE * output,
    std::map< std::string, std::pair<int, int> > & nbExamplesTrain,
    std::map< std::string, std::pair<int, int> > & nbExamplesDev,
    std::map< std::string, std::vector<float> > & trainScores,
    std::map< std::string, std::vector<float> > & devScores,
    std::map<std::string, int> & bestIter,
    int nbIter, int curIter);

  /// @brief For every Classifier, shuffle its training examples.
  ///
  /// @param examples Map each Classifier to a set of training examples.
  void shuffleAllExamples(std::map<Classifier*, MLP::Examples > & examples);

  public :

  /// @brief Construct a new Trainer without a dev set.
  ///
  /// @param tm The TapeMachine to use.
  /// @param bd The BD to use.
  /// @param config The config to use.
  Trainer(TapeMachine & tm, BD & bd, Config & config);
  /// @brief Construct a new Trainer with a dev set.
  ///
  /// @param tm The TapeMachine to use.
  /// @param bd The BD corresponding to the training dataset.
  /// @param config The Config corresponding to bd.
  /// @param devBD The BD corresponding to the dev dataset.
  /// @param devConfig The Config corresponding to devBD.
  Trainer(TapeMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig);
  /// @brief Train the TapeMachine.
  ///
  /// @param nbIter The number of training epochs.
  /// @param batchSize The size of each batch.
  /// @param mustShuffle Will the examples be shuffled after every epoch ?
  void train(int nbIter, int batchSize, bool mustShuffle);
};

#endif