Trainer.hpp
#ifndef TRAINER__H
#define TRAINER__H

#include <cstddef>
#include <filesystem>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "ReadingMachine.hpp"
#include "ConfigDataset.hpp"
#include "SubConfig.hpp"
class Trainer
{
  public :

  // Actions that can be scheduled during training.
  enum TrainAction
  {
    ExtractGold,
    ExtractDynamic,
    DeleteExamples,
    ResetOptimizer,
    ResetParameters,
    Save
  };

  // Schedule of training actions, presumably keyed by epoch number.
  using TrainStrategy = std::map<std::size_t, std::set<TrainAction>>;

  // Parse a TrainAction from its textual name (e.g. "ExtractGold").
  static TrainAction str2TrainAction(const std::string & s);
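
  // Illustrative only (values made up): a client-side strategy that extracts
  // gold examples at epoch 0 and switches to dynamic-oracle extraction at
  // epoch 5, saving the machine in both phases:
  //
  //   Trainer::TrainStrategy strategy{
  //     {0, {Trainer::ExtractGold, Trainer::Save}},
  //     {5, {Trainer::DeleteExamples, Trainer::ExtractDynamic, Trainer::Save}}};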
  private :

  // Safety cap on the number of examples that can be kept in memory.
  static constexpr std::size_t safetyNbExamplesMax = 10*1000*1000;

  // Buffer of extracted training examples (input contexts and gold classes).
  struct Examples
  {
    std::vector<torch::Tensor> contexts;
    std::vector<torch::Tensor> classes;
    int currentExampleIndex{0};
    int lastSavedIndex{0};

    void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int currentEpoch, bool dynamicOracle);
    void addContext(std::vector<std::vector<long>> & context);
    void addClass(const LossFunction & lossFct, int nbClasses, const std::vector<long> & goldIndexes);
  };

  private :

  using Dataset = ConfigDataset;
  using DataLoader = std::unique_ptr<torch::data::StatefulDataLoader<Dataset>>;

  private :

  ReadingMachine & machine;
  std::unique_ptr<Dataset> trainDataset{nullptr};
  std::unique_ptr<Dataset> devDataset{nullptr};
  DataLoader dataLoader{nullptr};
  DataLoader devDataLoader{nullptr};
  std::size_t epochNumber{0};
  int batchSize;

  private :

  void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold);
  float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);

  public :

  Trainer(ReadingMachine & machine, int batchSize);

  // Dataset construction and loading.
  void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold);
  void extractActionSequence(BaseConfig & config);
  void makeDataLoader(std::filesystem::path dir);
  void makeDevDataLoader(std::filesystem::path dir);

  // One pass over the training data, and one evaluation pass over the dev data.
  float epoch(bool printAdvancement);
  float evalOnDev(bool printAdvancement);
};

#endif
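
A minimal usage sketch of the interface declared above (not taken from the project: the driver function, directory arguments and parameter values are placeholders, and the call order only reflects what the declarations suggest):

#include <iostream>
#include "Trainer.hpp"

// Hypothetical driver: builds the dataset and loaders once, then alternates
// training epochs with evaluation on the dev set.
void trainLoop(ReadingMachine & machine, BaseConfig & goldConfig, std::filesystem::path trainDir, std::filesystem::path devDir, int nbEpochs)
{
  Trainer trainer(machine, /*batchSize=*/64);

  trainer.createDataset(goldConfig, /*debug=*/false, trainDir, /*epoch=*/0, /*dynamicOracle=*/false, /*explorationThreshold=*/0.0f);
  trainer.makeDataLoader(trainDir);
  trainer.makeDevDataLoader(devDir);

  for (int e = 0; e < nbEpochs; ++e)
  {
    float trainScore = trainer.epoch(/*printAdvancement=*/true);
    float devScore = trainer.evalOnDev(/*printAdvancement=*/true);
    std::cout << "epoch " << e << " : train = " << trainScore << ", dev = " << devScore << std::endl;
  }
}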