// Trainer.cpp
#include "Trainer.hpp"
#include "SubConfig.hpp"
// Builds a trainer: initializes the `machine` member from the given reading
// machine and records `batchSize`, which makeDataLoader / makeDevDataLoader
// later pass to the data loader options.
Trainer::Trainer(ReadingMachine & machine, int batchSize)
  : machine(machine),
    batchSize(batchSize)
{
}
// Replaces the training dataset with a new Dataset built from `dir`, then
// creates the training data loader over it.
//
// dir: path handed to the Dataset constructor.
void Trainer::makeDataLoader(std::filesystem::path dir)
{
  trainDataset.reset(new Dataset(dir));

  // Batch size comes from the trainer; workers and max_jobs are both set to 0.
  auto loaderOptions = torch::data::DataLoaderOptions(batchSize);
  loaderOptions.workers(0);
  loaderOptions.max_jobs(0);

  dataLoader = torch::data::make_data_loader(*trainDataset, loaderOptions);
}
// Replaces the dev dataset with a new Dataset built from `dir`, then creates
// the dev data loader over it.
//
// dir: path handed to the Dataset constructor.
void Trainer::makeDevDataLoader(std::filesystem::path dir)
{
  devDataset.reset(new Dataset(dir));

  // Batch size comes from the trainer; workers and max_jobs are both set to 0.
  auto loaderOptions = torch::data::DataLoaderOptions(batchSize);
  loaderOptions.workers(0);
  loaderOptions.max_jobs(0);

  devDataLoader = torch::data::make_data_loader(*devDataset, loaderOptions);
}
void Trainer::createDataset(std::vector<BaseConfig> & goldConfigs, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold)
{
std::vector<SubConfig> configs;
for (auto & goldConfig : goldConfigs)
configs.emplace_back(goldConfig, goldConfig.getNbLines());
machine.trainMode(false);
extractExamples(configs, debug, dir, epoch, dynamicOracle, explorationThreshold);
machine.saveDicts();
}
void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold)
{
torch::AutoGradMode useGrad(false);
int maxNbExamplesPerFile = 50000;
std::map<std::string, Examples> examplesPerState;
std::filesystem::create_directories(dir);
auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle);
if (std::filesystem::exists(currentEpochAllExtractedFile))
return;
fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");
int totalNbExamples = 0;
for (auto & config : configs)
{
config.addPredicted(machine.getPredicted());
config.setStrategy(machine.getStrategyDefinition());
config.setState(config.getStrategy().getInitialState());
machine.getClassifier(config.getState())->setState(config.getState());
while (true)
{
if (debug)
config.printForDebug(stderr);
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
config.setAppliableTransitions(appliableTransitions);
std::vector<std::vector<long>> context;