#ifndef TRAINER__H #define TRAINER__H #include "ReadingMachine.hpp" #include "ConfigDataset.hpp" #include "SubConfig.hpp" class Trainer { public : enum TrainAction { ExtractGold, ExtractDynamic, DeleteExamples, ResetOptimizer, ResetParameters, Save }; using TrainStrategy = std::map<std::size_t, std::set<TrainAction>>; static TrainAction str2TrainAction(const std::string & s); private : static constexpr std::size_t safetyNbExamplesMax = 10*1000*1000; struct Examples { std::vector<torch::Tensor> contexts; std::vector<torch::Tensor> classes; int currentExampleIndex{0}; int lastSavedIndex{0}; void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int currentEpoch, bool dynamicOracle); void addContext(std::vector<std::vector<long>> & context); void addClass(const LossFunction & lossFct, int nbClasses, const std::vector<long> & goldIndexes); }; private : using Dataset = ConfigDataset; using DataLoader = std::unique_ptr<torch::data::StatefulDataLoader<Dataset>>; private : ReadingMachine & machine; std::unique_ptr<Dataset> trainDataset{nullptr}; std::unique_ptr<Dataset> devDataset{nullptr}; DataLoader dataLoader{nullptr}; DataLoader devDataLoader{nullptr}; std::size_t epochNumber{0}; int batchSize; private : void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold); float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples); public : Trainer(ReadingMachine & machine, int batchSize); void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold); void extractActionSequence(BaseConfig & config); void makeDataLoader(std::filesystem::path dir); void makeDevDataLoader(std::filesystem::path dir); float epoch(bool printAdvancement); float evalOnDev(bool printAdvancement); }; #endif