#ifndef TRAINER__H
#define TRAINER__H

#include "ReadingMachine.hpp"
#include "ConfigDataset.hpp"
#include "SubConfig.hpp"

class Trainer
{
  private :

  using Dataset = ConfigDataset;
  using DataLoader = std::unique_ptr<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler>, std::default_delete<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler> > >;

  private :

  ReadingMachine & machine;
  DataLoader dataLoader{nullptr};
  DataLoader devDataLoader{nullptr};
  std::unique_ptr<torch::optim::Adam> optimizer;
  std::size_t epochNumber{0};
  int batchSize{64};
  int nbExamples{0};

  private :

  void extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes);
  float processDataset(DataLoader & loader, bool train, bool printAdvancement);

  public :

  Trainer(ReadingMachine & machine);
  void createDataset(SubConfig & goldConfig, bool debug);
  void createDevDataset(SubConfig & goldConfig, bool debug);
  float epoch(bool printAdvancement);
  float evalOnDev(bool printAdvancement);
};

#endif