#include "Trainer.hpp" #include "SubConfig.hpp" Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize) { } void Trainer::makeDataLoader(std::filesystem::path dir) { trainDataset.reset(new Dataset(dir)); dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); } void Trainer::makeDevDataLoader(std::filesystem::path dir) { devDataset.reset(new Dataset(dir)); devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); } void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle) { SubConfig config(goldConfig, goldConfig.getNbLines()); machine.trainMode(false); extractExamples(config, debug, dir, epoch, dynamicOracle); machine.saveDicts(); } void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle) { torch::AutoGradMode useGrad(false); int maxNbExamplesPerFile = 50000; std::map<std::string, Examples> examplesPerState; std::filesystem::create_directories(dir); config.addPredicted(machine.getPredicted()); config.setStrategy(machine.getStrategyDefinition()); config.setState(config.getStrategy().getInitialState()); machine.getClassifier()->setState(config.getState()); auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle); if (std::filesystem::exists(currentEpochAllExtractedFile)) return; fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : ""); int totalNbExamples = 0; while (true) { if (debug) config.printForDebug(stderr); if (machine.hasSplitWordTransitionSet()) config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config); config.setAppliableTransitions(appliableTransitions); std::vector<std::vector<long>> context; try { context = machine.getClassifier()->getNN()->extractContext(config); } catch(std::exception & e) { util::myThrow(fmt::format("Failed to extract context : {}", e.what())); } Transition * transition = nullptr; Transition * goldTransition = nullptr; goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions, dynamicOracle); if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") { auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); int chosenTransition = -1; float bestScore = std::numeric_limits<float>::min(); for (unsigned int i = 0; i < prediction.size(0); i++) { float score = prediction[i].item<float>(); if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config)) { chosenTransition = i; bestScore = score; } } transition = machine.getTransitionSet().getTransition(chosenTransition); } else { transition = goldTransition; } if (!transition or !goldTransition) { config.printForDebug(stderr); util::myThrow("No transition appliable !"); } int goldIndex = machine.getTransitionSet().getTransitionIndex(goldTransition); totalNbExamples += context.size(); if (totalNbExamples >= 
    examplesPerState[config.getState()].addContext(context);
    examplesPerState[config.getState()].addClass(goldIndex);
    examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);

    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = config.getStrategy().getMovement(config, transition->getName());
    if (debug)
      fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    machine.getClassifier()->setState(movement.first);
    config.moveWordIndexRelaxed(movement.second);

    if (config.needsUpdate())
      config.update();
  }

  // Flush whatever is left in the per-state buffers (threshold 0 forces a save).
  for (auto & it : examplesPerState)
    it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);

  // Touch the marker file so a re-run of the same epoch skips extraction.
  std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
  if (!f)
    util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
  std::fclose(f);

  fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(totalNbExamples));
}

float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples)
{
  constexpr int printInterval = 50;
  int nbExamplesProcessed = 0;
  int totalNbExamplesProcessed = 0;
  float totalLoss = 0.0;
  float lossSoFar = 0.0;

  torch::AutoGradMode useGrad(train);
  machine.trainMode(train);
  machine.setDictsState(Dict::State::Closed);

  auto lossFct = torch::nn::CrossEntropyLoss();
  auto pastTime = std::chrono::high_resolution_clock::now();

  for (auto & batch : *loader)
  {
    if (train)
      machine.getClassifier()->getOptimizer().zero_grad();

    auto data = std::get<0>(batch);
    auto labels = std::get<1>(batch);
    auto state = std::get<2>(batch);

    machine.getClassifier()->setState(state);

    auto prediction = machine.getClassifier()->getNN()(data);
    if (prediction.dim() == 1)
      prediction = prediction.unsqueeze(0);

    // CrossEntropyLoss expects predictions of shape [N, C] and labels of shape [N].
    labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));

    auto loss = machine.getClassifier()->getLossMultiplier() * lossFct(prediction, labels);
    try
    {
      totalLoss += loss.item<float>();
      lossSoFar += loss.item<float>();
    } catch (std::exception & e)
    {
      util::myThrow(e.what());
    }

    if (train)
    {
      loss.backward();
      machine.getClassifier()->getOptimizer().step();
    }

    totalNbExamplesProcessed += torch::numel(labels);

    if (printAdvancement)
    {
      nbExamplesProcessed += torch::numel(labels);
      if (nbExamplesProcessed >= printInterval)
      {
        auto actualTime = std::chrono::high_resolution_clock::now();
        double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
        pastTime = actualTime;
        auto speed = (int)(nbExamplesProcessed/seconds);
        auto progression = 100.0*totalNbExamplesProcessed / nbExamples;
        auto statusStr = fmt::format("{:6.2f}% loss={:<7.3f} speed={:<6}ex/s", progression, lossSoFar, speed);
        if (train)
          fmt::print(stderr, "\r{:80}\rtraining : {}", "", statusStr);
        else
          fmt::print(stderr, "\r{:80}\reval on dev : {}", "", statusStr);
        lossSoFar = 0;
        nbExamplesProcessed = 0;
      }
    }
  }

  // Sum of per-batch losses, normalized by the total number of examples.
  return totalLoss / nbExamples;
}

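// epoch() and evalOnDev() below are thin wrappers around processDataset(). A typical
// caller would alternate them once the data loaders exist; a sketch only, since the
// surrounding training loop lives outside this file (names trainDir, devDir and
// nbEpochs are illustrative):
//
//   trainer.makeDataLoader(trainDir);
//   trainer.makeDevDataLoader(devDir);
//   for (int e = 0; e < nbEpochs; e++)
//   {
//     float trainLoss = trainer.epoch(true);
//     float devLoss = trainer.evalOnDev(true);
//     fmt::print(stderr, "epoch {} : train={} dev={}\n", e, trainLoss, devLoss);
//   }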
float Trainer::epoch(bool printAdvancement)
{
  return processDataset(dataLoader, true, printAdvancement, trainDataset->size().value());
}

float Trainer::evalOnDev(bool printAdvancement)
{
  return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
}

void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int epoch, bool dynamicOracle)
{
  if (currentExampleIndex-lastSavedIndex < (int)threshold)
    return;
  if (contexts.empty())
    return;

  // Pack the buffered examples into one [N, contextSize + 1] tensor:
  // the context columns first, the gold class as the last column.
  auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
  auto filename = fmt::format("{}_{}-{}.{}.{}.tensor", state, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
  torch::save(tensorToSave, dir/filename);
  lastSavedIndex = currentExampleIndex;
  contexts.clear();
  classes.clear();
}

void Trainer::Examples::addContext(std::vector<std::vector<long>> & context)
{
  for (auto & element : context)
    contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone());

  currentExampleIndex += context.size();
}

void Trainer::Examples::addClass(int goldIndex)
{
  auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
  gold[0] = goldIndex;

  // All context rows extracted from the same configuration share the same gold class.
  while (classes.size() < contexts.size())
    classes.emplace_back(gold);
}

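// A consumer of the .tensor files written above (presumably the Dataset class, which is
// not shown here) can recover contexts and classes by splitting off the last column.
// A minimal sketch, assuming a loaded tensor t of shape [N, contextSize + 1]:
//
//   torch::Tensor t;
//   torch::load(t, path);
//   auto contexts = t.narrow(1, 0, t.size(1)-1); // [N, contextSize]
//   auto classes  = t.select(1, t.size(1)-1);    // [N]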
void Trainer::fillDicts(SubConfig & config, bool debug)
{
  torch::AutoGradMode useGrad(false);

  config.addPredicted(machine.getPredicted());
  config.setStrategy(machine.getStrategyDefinition());
  config.setState(config.getStrategy().getInitialState());
  machine.getClassifier()->setState(config.getState());

  // Follow the gold oracle over the whole corpus, calling extractContext() only for
  // its side effect of populating the dictionaries; no example is stored.
  while (true)
  {
    if (debug)
      config.printForDebug(stderr);

    if (machine.hasSplitWordTransitionSet())
      config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));

    auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
    config.setAppliableTransitions(appliableTransitions);

    try
    {
      machine.getClassifier()->getNN()->extractContext(config);
    } catch (std::exception & e)
    {
      util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
    }

    Transition * goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions);

    if (!goldTransition)
    {
      config.printForDebug(stderr);
      util::myThrow("No transition appliable !");
    }

    goldTransition->apply(config);
    config.addToHistory(goldTransition->getName());

    auto movement = config.getStrategy().getMovement(config, goldTransition->getName());
    if (debug)
      fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", goldTransition->getName(), movement.first, movement.second);
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    machine.getClassifier()->setState(movement.first);
    config.moveWordIndexRelaxed(movement.second);

    if (config.needsUpdate())
      config.update();
  }
}

Trainer::TrainAction Trainer::str2TrainAction(const std::string & s)
{
  if (s == "ExtractGold")
    return TrainAction::ExtractGold;
  if (s == "ExtractDynamic")
    return TrainAction::ExtractDynamic;
  if (s == "DeleteExamples")
    return TrainAction::DeleteExamples;
  if (s == "ResetOptimizer")
    return TrainAction::ResetOptimizer;
  if (s == "ResetParameters")
    return TrainAction::ResetParameters;
  if (s == "Save")
    return TrainAction::Save;

  util::myThrow(fmt::format("unknown TrainAction '{}'", s));

  // Unreachable: myThrow does not return, but the compiler cannot know that.
  return TrainAction::ExtractGold;
}
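// str2TrainAction() resolves action names to the TrainAction enum; a hypothetical
// caller parsing a training schedule would use it like this:
//
//   auto action = Trainer::str2TrainAction("ExtractGold"); // TrainAction::ExtractGold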