#include "Trainer.hpp"
#include "SubConfig.hpp"

#include <memory>

Trainer::Trainer(ReadingMachine & machine) : machine(machine)
{
}

// Build the training DataLoader from examples extracted out of `config`,
// and (re)create the Adam optimizer over the classifier's parameters.
void Trainer::createDataset(SubConfig & config, bool debug)
{
  machine.trainMode(true);

  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  extractExamples(config, debug, contexts, classes);

  nbExamples = classes.size();

  dataLoader = torch::data::make_data_loader(
      Dataset(contexts, classes).map(torch::data::transforms::Stack<>()),
      torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));

  // NOTE(review): learning rate / betas are hard-coded here; consider lifting
  // them into configuration if experiments need to vary them.
  optimizer = std::make_unique<torch::optim::Adam>(
      machine.getClassifier()->getNN()->parameters(),
      torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999));
}

// Build the dev-set DataLoader. Same pipeline as createDataset, but in eval
// mode and without touching the optimizer.
void Trainer::createDevDataset(SubConfig & config, bool debug)
{
  machine.trainMode(false);

  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  extractExamples(config, debug, contexts, classes);

  devDataLoader = torch::data::make_data_loader(
      Dataset(contexts, classes).map(torch::data::transforms::Stack<>()),
      torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}

// Run the oracle over `config`, emitting one (context, gold class) pair per
// extracted context row. `contexts` and `classes` are appended to in lockstep.
// Throws (via util::myThrow) if no transition is appliable or a word-index
// move fails.
void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes)
{
  fmt::print(stderr, "[{}] Starting to extract examples\n", util::getTime());

  config.addPredicted(machine.getPredicted());
  config.setState(machine.getStrategy().getInitialState());

  while (true)
  {
    if (debug)
      config.printForDebug(stderr);

    config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));

    // The oracle: best transition appliable in the current configuration.
    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
    {
      config.printForDebug(stderr);
      util::myThrow("No transition appliable !");
    }

    std::vector<std::vector<long>> context;

    try
    {
      context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState()));
      for (auto & element : context)
        // from_blob does not own the data; clone() so the tensor survives
        // `context` going out of scope.
        contexts.emplace_back(torch::from_blob(element.data(), {static_cast<long>(element.size())}, torch::TensorOptions(torch::kLong)).clone());
    }
    catch (const std::exception & e)
    {
      util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
    }

    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
    gold[0] = goldIndex;

    // One gold label per context row, keeping contexts/classes aligned.
    classes.insert(classes.end(), context.size(), gold);

    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (debug)
      fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
    {
      config.printForDebug(stderr);
      util::myThrow(fmt::format("Cannot move word index by {}", movement.second));
    }

    if (config.needsUpdate())
      config.update();
  }

  fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(classes.size()));
}

// One pass over `loader`. If `train`, performs backprop + optimizer steps;
// otherwise gradients are disabled. Returns the summed cross-entropy loss
// over all batches. `printAdvancement` enables a throttled progress line.
float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement)
{
  constexpr int printInterval = 50; // print progress every ~50 examples
  int nbExamplesProcessed = 0;
  float totalLoss = 0.0;
  float lossSoFar = 0.0;
  int currentBatchNumber = 0;

  torch::AutoGradMode useGrad(train);
  machine.trainMode(train);

  auto lossFct = torch::nn::CrossEntropyLoss();
  auto pastTime = std::chrono::high_resolution_clock::now();

  for (auto & batch : *loader)
  {
    if (train)
      optimizer->zero_grad();

    auto data = batch.data;
    auto labels = batch.target.squeeze();

    auto prediction = machine.getClassifier()->getNN()(data);
    // A batch of size 1 squeezes down to 1-D; restore the batch dimension.
    if (prediction.dim() == 1)
      prediction = prediction.unsqueeze(0);

    // Likewise, a single label squeezes to a 0-dim scalar; CrossEntropyLoss
    // wants a 1-D label tensor.
    labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));

    auto loss = lossFct(prediction, labels);
    try
    {
      // item<float>() synchronizes with the device; extract the scalar once.
      const float batchLoss = loss.item<float>();
      totalLoss += batchLoss;
      lossSoFar += batchLoss;
    }
    catch (const std::exception & e)
    {
      util::myThrow(e.what());
    }

    if (train)
    {
      loss.backward();
      optimizer->step();
    }

    if (printAdvancement)
    {
      nbExamplesProcessed += labels.size(0);
      ++currentBatchNumber;
      if (nbExamplesProcessed >= printInterval)
      {
        auto actualTime = std::chrono::high_resolution_clock::now();
        double seconds = std::chrono::duration<double, std::milli>(actualTime - pastTime).count() / 1000.0;
        pastTime = actualTime;
        // Guard against a zero elapsed time: casting the resulting inf to
        // int would be undefined behavior.
        int speed = seconds > 0.0 ? static_cast<int>(nbExamplesProcessed / seconds) : 0;
        if (train)
          fmt::print(stderr, "\r{:80}\rcurrent epoch : {:6.2f}% loss={:<7.3f} speed={:<6}ex/s", "", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar, speed);
        else
          fmt::print(stderr, "\r{:80}\reval on dev : loss={:<7.3f} speed={:<6}ex/s", "", lossSoFar, speed);
        lossSoFar = 0;
        nbExamplesProcessed = 0;
      }
    }
  }

  return totalLoss;
}

// One training epoch over the train DataLoader; returns the summed loss.
float Trainer::epoch(bool printAdvancement)
{
  return processDataset(dataLoader, true, printAdvancement);
}

// One evaluation pass over the dev DataLoader; returns the summed loss.
float Trainer::evalOnDev(bool printAdvancement)
{
  return processDataset(devDataLoader, false, printAdvancement);
}

// Restore the optimizer state (moments, step counts) from `path`.
void Trainer::loadOptimizer(std::filesystem::path path)
{
  torch::load(*optimizer, path);
}

// Serialize the optimizer state to `path`.
void Trainer::saveOptimizer(std::filesystem::path path)
{
  torch::save(*optimizer, path);
}