-
Franck Dary authored
Trainer.hpp 950 B
#ifndef TRAINER__H
#define TRAINER__H
#include "ReadingMachine.hpp"
#include "ConfigDataset.hpp"
#include "SubConfig.hpp"
class Trainer
{
private :
using Dataset = ConfigDataset;
using DataLoader = std::unique_ptr<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler>, std::default_delete<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler> > >;
private :
ReadingMachine & machine;
DataLoader dataLoader{nullptr};
std::unique_ptr<torch::optim::Adam> optimizer;
std::size_t epochNumber{0};
int batchSize{50};
int nbExamples{0};
public :
Trainer(ReadingMachine & machine);
void createDataset(SubConfig & goldConfig, bool debug);
float epoch(bool printAdvancement);
};
#endif