Newer
Older
#ifndef TRAINER__H
#define TRAINER__H
#include "ReadingMachine.hpp"
#include "ConfigDataset.hpp"
#include "SubConfig.hpp"
class Trainer
{
private :
using Dataset = ConfigDataset;
using DataLoader = std::unique_ptr<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler>, std::default_delete<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler> > >;
private :
Franck Dary
committed
std::unique_ptr<torch::optim::Adam> optimizer;
void createDataset(SubConfig & goldConfig, bool debug);
float epoch(bool printAdvancement);