#ifndef TRAINER__H #define TRAINER__H #include "ReadingMachine.hpp" #include "ConfigDataset.hpp" #include "SubConfig.hpp" class Trainer { private : using Dataset = ConfigDataset; using DataLoader = std::unique_ptr<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler>, std::default_delete<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler> > >; private : ReadingMachine & machine; DataLoader dataLoader{nullptr}; std::unique_ptr<torch::optim::Adam> optimizer; std::size_t epochNumber{0}; int batchSize{1}; int nbExamples{0}; public : Trainer(ReadingMachine & machine); void createDataset(SubConfig & goldConfig, bool debug); float epoch(bool printAdvancement); }; #endif