#ifndef TRAINER__H #define TRAINER__H #include "ReadingMachine.hpp" #include "ConfigDataset.hpp" #include "SubConfig.hpp" class Trainer { private : using Dataset = ConfigDataset; using DataLoader = std::unique_ptr<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler>, std::default_delete<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler> > >; private : ReadingMachine & machine; DataLoader dataLoader{nullptr}; DataLoader devDataLoader{nullptr}; std::unique_ptr<torch::optim::Adam> optimizer; std::size_t epochNumber{0}; int batchSize{64}; int nbExamples{0}; private : void extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes); float processDataset(DataLoader & loader, bool train, bool printAdvancement); public : Trainer(ReadingMachine & machine); void createDataset(SubConfig & goldConfig, bool debug); void createDevDataset(SubConfig & goldConfig, bool debug); float epoch(bool printAdvancement); float evalOnDev(bool printAdvancement); }; #endif