#include "Trainer.hpp" #include "SubConfig.hpp" Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize) { } void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval) { SubConfig config(goldConfig, goldConfig.getNbLines()); machine.trainMode(true); extractExamples(config, debug, dir, epoch, dynamicOracleInterval); trainDataset.reset(new Dataset(dir)); dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); } void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval) { SubConfig config(goldConfig, goldConfig.getNbLines()); machine.trainMode(false); extractExamples(config, debug, dir, epoch, dynamicOracleInterval); devDataset.reset(new Dataset(dir)); devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); } void Trainer::saveExamples(std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes, int & lastSavedIndex, int & currentExampleIndex, std::filesystem::path dir) { auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1); auto filename = fmt::format("{}-{}.tensor", lastSavedIndex, currentExampleIndex-1); torch::save(tensorToSave, dir/filename); lastSavedIndex = currentExampleIndex; contexts.clear(); classes.clear(); } void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval) { torch::AutoGradMode useGrad(false); machine.setDictsState(Dict::State::Open); int maxNbExamplesPerFile = 250000; int currentExampleIndex = 0; int lastSavedIndex = 0; std::vector<torch::Tensor> contexts; std::vector<torch::Tensor> classes; std::filesystem::create_directories(dir); config.addPredicted(machine.getPredicted()); config.setState(machine.getStrategy().getInitialState()); machine.getStrategy().reset(); auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch); bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile); if (epoch != 0 and (dynamicOracleInterval == -1 or epoch % dynamicOracleInterval)) mustExtract = false; if (!mustExtract) return; bool dynamicOracle = epoch != 0; fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? 
", dynamic oracle" : ""); for (auto & entry : std::filesystem::directory_iterator(dir)) if (entry.is_regular_file()) std::filesystem::remove(entry.path()); while (true) { if (debug) config.printForDebug(stderr); if (machine.hasSplitWordTransitionSet()) config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); std::vector<std::vector<long>> context; try { context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState())); for (auto & element : context) contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone()); } catch(std::exception & e) { util::myThrow(fmt::format("Failed to extract context : {}", e.what())); } Transition * transition = nullptr; if (dynamicOracle and config.getState() != "tokenizer") { auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); int chosenTransition = -1; float bestScore = std::numeric_limits<float>::min(); for (unsigned int i = 0; i < prediction.size(0); i++) { float score = prediction[i].item<float>(); if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config)) { chosenTransition = i; bestScore = score; } } transition = machine.getTransitionSet().getTransition(chosenTransition); } else { transition = machine.getTransitionSet().getBestAppliableTransition(config); } if (!transition) { config.printForDebug(stderr); util::myThrow("No transition appliable !"); } int goldIndex = machine.getTransitionSet().getTransitionIndex(transition); auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong)); gold[0] = goldIndex; currentExampleIndex += context.size(); classes.insert(classes.end(), context.size(), gold); if (currentExampleIndex-lastSavedIndex >= maxNbExamplesPerFile) saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir); transition->apply(config); config.addToHistory(transition->getName()); auto movement = machine.getStrategy().getMovement(config, transition->getName()); if (debug) fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second); if (movement == Strategy::endMovement) break; config.setState(movement.first); config.moveWordIndexRelaxed(movement.second); if (config.needsUpdate()) config.update(); } if (!contexts.empty()) saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir); std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w"); if (!f) util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str())); std::fclose(f); machine.saveDicts(); fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(currentExampleIndex)); } float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples) { constexpr int printInterval = 50; int nbExamplesProcessed = 0; int totalNbExamplesProcessed = 0; float totalLoss = 0.0; float lossSoFar = 0.0; torch::AutoGradMode useGrad(train); machine.trainMode(train); machine.setDictsState(Dict::State::Closed); auto lossFct = torch::nn::CrossEntropyLoss(); auto pastTime = std::chrono::high_resolution_clock::now(); for (auto & batch : *loader) { if (train) machine.getClassifier()->getOptimizer().zero_grad(); auto data = 
float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples)
{
  constexpr int printInterval = 50;
  int nbExamplesProcessed = 0;
  int totalNbExamplesProcessed = 0;
  float totalLoss = 0.0;
  float lossSoFar = 0.0;

  torch::AutoGradMode useGrad(train);
  machine.trainMode(train);
  machine.setDictsState(Dict::State::Closed);

  auto lossFct = torch::nn::CrossEntropyLoss();
  auto pastTime = std::chrono::high_resolution_clock::now();

  for (auto & batch : *loader)
  {
    if (train)
      machine.getClassifier()->getOptimizer().zero_grad();

    auto data = batch.first;
    auto labels = batch.second;

    auto prediction = machine.getClassifier()->getNN()(data);
    if (prediction.dim() == 1)
      prediction = prediction.unsqueeze(0);

    // CrossEntropyLoss expects labels of shape (batchSize,).
    labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));

    auto loss = lossFct(prediction, labels);
    try
    {
      totalLoss += loss.item<float>();
      lossSoFar += loss.item<float>();
    } catch (std::exception & e)
    {
      util::myThrow(e.what());
    }

    if (train)
    {
      loss.backward();
      machine.getClassifier()->getOptimizer().step();
    }

    totalNbExamplesProcessed += torch::numel(labels);

    if (printAdvancement)
    {
      nbExamplesProcessed += torch::numel(labels);
      if (nbExamplesProcessed >= printInterval)
      {
        auto actualTime = std::chrono::high_resolution_clock::now();
        double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
        pastTime = actualTime;
        auto speed = (int)(nbExamplesProcessed/seconds);
        auto progression = 100.0*totalNbExamplesProcessed / nbExamples;
        auto statusStr = fmt::format("{:6.2f}% loss={:<7.3f} speed={:<6}ex/s", progression, lossSoFar, speed);
        if (train)
          fmt::print(stderr, "\r{:80}\rtraining : {}", "", statusStr);
        else
          fmt::print(stderr, "\r{:80}\reval on dev : {}", "", statusStr);
        lossSoFar = 0;
        nbExamplesProcessed = 0;
      }
    }
  }

  return totalLoss / nbExamples;
}

float Trainer::epoch(bool printAdvancement)
{
  return processDataset(dataLoader, true, printAdvancement, trainDataset->size().value());
}

float Trainer::evalOnDev(bool printAdvancement)
{
  return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
}
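// Usage sketch (hypothetical driver code, not part of this file) : one way the
// Trainer API above could drive training. The machine, goldTrain, goldDev,
// nbEpochs and dynamicOracleInterval variables and the directory names are
// assumptions made for the example.
#if 0
Trainer trainer(machine, /*batchSize=*/64);
for (int epoch = 0; epoch < nbEpochs; epoch++)
{
  trainer.createDataset(goldTrain, /*debug=*/false, "train_examples", epoch, dynamicOracleInterval);
  float trainLoss = trainer.epoch(/*printAdvancement=*/true);

  trainer.createDevDataset(goldDev, /*debug=*/false, "dev_examples", epoch, dynamicOracleInterval);
  float devLoss = trainer.evalOnDev(/*printAdvancement=*/true);

  fmt::print(stderr, "\nepoch {} : train loss={:.3f} dev loss={:.3f}\n", epoch, trainLoss, devLoss);
}
#endif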