#include "Trainer.hpp" #include "SubConfig.hpp" #include <execution> Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize) { } void Trainer::makeDataLoader(std::filesystem::path dir) { trainDataset.reset(new Dataset(dir)); dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); } void Trainer::makeDevDataLoader(std::filesystem::path dir) { devDataset.reset(new Dataset(dir)); devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); } void Trainer::createDataset(std::vector<BaseConfig> & goldConfigs, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold, bool memcheck) { std::vector<SubConfig> configs; for (auto & goldConfig : goldConfigs) configs.emplace_back(goldConfig, goldConfig.getNbLines()); machine.trainMode(false); extractExamples(configs, debug, dir, epoch, dynamicOracle, explorationThreshold, memcheck); machine.saveDicts(); } void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold, bool memcheck) { torch::AutoGradMode useGrad(false); int maxNbExamplesPerFile = 50000; std::unordered_map<std::string, Examples> examplesPerState; std::mutex examplesMutex; std::filesystem::create_directories(dir); auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle); if (std::filesystem::exists(currentEpochAllExtractedFile)) return; fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : ""); std::atomic<int> totalNbExamples = 0; if (memcheck) fmt::print(stderr, "[{}] Memory : {}\n", util::getTime(), util::getMemUsage()); NeuralNetworkImpl::setDevice(torch::kCPU); machine.to(NeuralNetworkImpl::getDevice()); std::for_each(std::execution::seq, configs.begin(), configs.end(), [this, maxNbExamplesPerFile, &examplesPerState, &totalNbExamples, debug, memcheck, dynamicOracle, explorationThreshold, dir, epoch, &examplesMutex](SubConfig & config) { config.addPredicted(machine.getPredicted()); config.setStrategy(machine.getStrategyDefinition()); config.setState(config.getStrategy().getInitialState()); while (true) { if (debug) config.printForDebug(stderr); if (machine.hasSplitWordTransitionSet()) config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config); config.setAppliableTransitions(appliableTransitions); torch::Tensor context; try { context = machine.getClassifier(config.getState())->getNN()->extractContext(config); } catch(std::exception & e) { util::myThrow(fmt::format("Failed to extract context : {}", e.what())); } Transition * transition = nullptr; auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle); Transition * goldTransition = goldTransitions[0]; if (config.getState() == "parser") goldTransitions[std::rand()%goldTransitions.size()]; int nbClasses = machine.getTransitionSet(config.getState()).size(); float bestScore = -std::numeric_limits<float>::max(); float entropy = 0.0; if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != 
"segmenter") { auto & classifier = *machine.getClassifier(config.getState()); auto prediction = classifier.isRegression() ? classifier.getNN()->forward(context, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(context, config.getState()).squeeze(0), 0); entropy = NeuralNetworkImpl::entropy(prediction); std::vector<int> candidates; for (unsigned int i = 0; i < prediction.size(0); i++) { float score = prediction[i].item<float>(); if (score > bestScore and appliableTransitions[i]) bestScore = score; } for (unsigned int i = 0; i < prediction.size(0); i++) { float score = prediction[i].item<float>(); if (appliableTransitions[i] and bestScore - score <= explorationThreshold) candidates.emplace_back(i); } transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]); for (auto & trans : goldTransitions) if (trans == transition) goldTransition = trans; } else { transition = goldTransition; } if (!transition or !goldTransition) { config.printForDebug(stderr); util::myThrow("No transition appliable !"); } std::vector<long> goldIndexes; bool exampleIsBanned = machine.getClassifier(config.getState())->exampleIsBanned(config); if (machine.getClassifier(config.getState())->isRegression()) { entropy = 0.0; auto errMessage = fmt::format("Invalid regression transition '{}'", transition->getName()); auto splited = util::split(transition->getName(), ' '); if (splited.size() != 3 or splited[0] != "WRITESCORE") util::myThrow(errMessage); auto col = splited[2]; splited = util::split(splited[1], '.'); if (splited.size() != 2) util::myThrow(errMessage); auto object = Config::str2object(splited[0]); int index = std::stoi(splited[1]); float regressionTarget = std::stof(config.getConst(col, config.getRelativeWordIndex(object, index), 0)); goldIndexes.emplace_back(util::float2long(regressionTarget)); } else { for (auto & t : goldTransitions) goldIndexes.emplace_back(machine.getTransitionSet(config.getState()).getTransitionIndex(t)); } if (!exampleIsBanned) { totalNbExamples += 1; if (totalNbExamples >= (int)safetyNbExamplesMax) util::error(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax))); examplesMutex.lock(); examplesPerState[config.getState()].addContext(context); examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes); examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle); examplesMutex.unlock(); } config.setChosenActionScore(bestScore); transition->apply(config, entropy); config.addToHistory(transition->getName()); auto movement = config.getStrategy().getMovement(config, transition->getName()); if (debug) fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second); if (movement == Strategy::endMovement) break; config.setState(movement.first); config.moveWordIndexRelaxed(movement.second); if (config.needsUpdate()) config.update(); } // End while true if (memcheck) fmt::print(stderr, "[{}] Memory : {}\n", util::getTime(), util::getMemUsage()); }); // End for on configs for (auto & it : examplesPerState) it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle); NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice()); machine.to(NeuralNetworkImpl::getDevice()); std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w"); if (!f) 
util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str())); std::fclose(f); if (memcheck) fmt::print(stderr, "[{}] Memory : {}\n", util::getTime(), util::getMemUsage()); fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(totalNbExamples)); } float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples) { constexpr int printInterval = 50; int nbExamplesProcessed = 0; int totalNbExamplesProcessed = 0; float totalLoss = 0.0; float lossSoFar = 0.0; torch::AutoGradMode useGrad(train); machine.trainMode(train); auto pastTime = std::chrono::high_resolution_clock::now(); for (auto & batch : *loader) { auto data = std::get<0>(batch); auto labels = std::get<1>(batch); auto state = std::get<2>(batch); if (train) machine.getClassifier(state)->getOptimizer().zero_grad(); auto prediction = machine.getClassifier(state)->getNN()->forward(data, state); if (prediction.dim() == 1) prediction = prediction.unsqueeze(0); if (machine.getClassifier(state)->isRegression()) { labels = labels.to(torch::kFloat); labels /= util::float2longScale; } auto lossParameter = machine.getClassifier(state)->getNN()->getLossParameter(state); auto loss = machine.getClassifier(state)->getLossMultiplier(state)*machine.getClassifier(state)->getLossFunction()(prediction, labels)*(1.0/torch::exp(lossParameter)) + lossParameter; float lossAsFloat = 0.0; try { lossAsFloat = loss.item<float>(); } catch(std::exception & e) {util::myThrow(e.what());} totalLoss += lossAsFloat; lossSoFar += lossAsFloat; if (train) { loss.backward(); machine.getClassifier(state)->getOptimizer().step(); } totalNbExamplesProcessed += labels.size(0); if (printAdvancement) { nbExamplesProcessed += labels.size(0); if (nbExamplesProcessed >= printInterval) { auto actualTime = std::chrono::high_resolution_clock::now(); double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0; pastTime = actualTime; auto speed = (int)(nbExamplesProcessed/seconds); auto progression = 100.0*totalNbExamplesProcessed / nbExamples; auto statusStr = fmt::format(lossSoFar/nbExamplesProcessed < 10.0 ? 
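// Saved example shards are named
//   {state}-{nbClasses}_{firstIndex}-{lastIndex}.{epoch}.{dynamicOracle}.tensor
// so "parser-4_0-49999.3.1.tensor" would hold examples 0..49999 of the "parser"
// state, extracted at epoch 3 with the dynamic oracle on (the concrete filename
// is illustrative, derived from the format string below).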
"{:6.2f}% loss={:<7.3f} speed={:<6}ex/s": "{:6.2f}% loss={:<7.0f} speed={:<6}ex/s", progression, lossSoFar / nbExamplesProcessed, speed); if (train) fmt::print(stderr, "\r{:80}\rtraining : {}", "", statusStr); else fmt::print(stderr, "\r{:80}\reval on dev : {}", "", statusStr); lossSoFar = 0; nbExamplesProcessed = 0; } } } return totalLoss / nbExamples; } float Trainer::epoch(bool printAdvancement) { return processDataset(dataLoader, true, printAdvancement, trainDataset->size().value()); } float Trainer::evalOnDev(bool printAdvancement) { return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value()); } void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int epoch, bool dynamicOracle) { if (currentExampleIndex-lastSavedIndex < (int)threshold) return; if (contexts.empty()) return; int nbClasses = classes[0].size(0); auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1); auto filename = fmt::format("{}-{}_{}-{}.{}.{}.tensor", state, nbClasses, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle); torch::save(tensorToSave, dir/filename); lastSavedIndex = currentExampleIndex; contexts.clear(); classes.clear(); } void Trainer::Examples::addContext(torch::Tensor & context) { contexts.emplace_back(context); currentExampleIndex += 1; } void Trainer::Examples::addClass(const LossFunction & lossFct, int nbClasses, const std::vector<long> & goldIndexes) { auto gold = lossFct.getGoldFromClassesIndexes(nbClasses, goldIndexes); while (classes.size() < contexts.size()) classes.emplace_back(gold); } Trainer::TrainAction Trainer::str2TrainAction(const std::string & s) { if (s == "ExtractGold") return TrainAction::ExtractGold; if (s == "ExtractDynamic") return TrainAction::ExtractDynamic; if (s == "DeleteExamples") return TrainAction::DeleteExamples; if (s == "ResetOptimizer") return TrainAction::ResetOptimizer; if (s == "ResetParameters") return TrainAction::ResetParameters; if (s == "Save") return TrainAction::Save; util::myThrow(fmt::format("unknown TrainAction '{}'", s)); return TrainAction::ExtractGold; } void Trainer::extractActionSequence(BaseConfig & config) { config.addPredicted(machine.getPredicted()); config.setStrategy(machine.getStrategyDefinition()); config.setState(config.getStrategy().getInitialState()); int curSeq = 0; int curSeqStartIndex = -1; int curInputIndex = 0; int curInputSeqSize = 0; int curOutputSeqSize = 0; int maxInputSeqSize = 0; int maxOutputSeqSize = 0; bool newSent = true; std::vector<std::string> transitionsIndexes; while (true) { if (config.hasCharacter(0)) curInputIndex = config.getCharacterIndex(); else curInputIndex = config.getWordIndex(); if (curSeqStartIndex == -1 or newSent) { newSent = false; curSeqStartIndex = curInputIndex; } if (machine.hasSplitWordTransitionSet()) config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config); config.setAppliableTransitions(appliableTransitions); auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true); Transition * transition = goldTransitions[0]; if (machine.getClassifier(config.getState())->isRegression()) util::myThrow("Regressions are not supported in extract action sequence mode"); transitionsIndexes.push_back(fmt::format("{}", 
void Trainer::extractActionSequence(BaseConfig & config)
{
  config.addPredicted(machine.getPredicted());
  config.setStrategy(machine.getStrategyDefinition());
  config.setState(config.getStrategy().getInitialState());

  int curSeq = 0;
  int curSeqStartIndex = -1;
  int curInputIndex = 0;
  int curInputSeqSize = 0;
  int curOutputSeqSize = 0;
  int maxInputSeqSize = 0;
  int maxOutputSeqSize = 0;
  bool newSent = true;
  std::vector<std::string> transitionsIndexes;

  while (true)
  {
    if (config.hasCharacter(0))
      curInputIndex = config.getCharacterIndex();
    else
      curInputIndex = config.getWordIndex();

    if (curSeqStartIndex == -1 or newSent)
    {
      newSent = false;
      curSeqStartIndex = curInputIndex;
    }

    if (machine.hasSplitWordTransitionSet())
      config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));

    auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
    config.setAppliableTransitions(appliableTransitions);

    auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true);
    Transition * transition = goldTransitions[0];

    if (machine.getClassifier(config.getState())->isRegression())
      util::myThrow("Regressions are not supported in extract action sequence mode");

    transitionsIndexes.push_back(fmt::format("{}", machine.getTransitionSet(config.getState()).getTransitionIndex(transition)));
    maxOutputSeqSize = std::max(maxOutputSeqSize, curOutputSeqSize++);
    curInputSeqSize = curInputIndex - curSeqStartIndex;
    maxInputSeqSize = std::max(maxInputSeqSize, curInputSeqSize++);

    // Every third sentence boundary, print the input chunk and its transition indexes.
    if (util::split(transition->getName(), ' ')[0] == "EOS")
      if (++curSeq % 3 == 0)
      {
        newSent = true;
        std::string curSeqText = "";
        for (int i = curSeqStartIndex; i <= curInputIndex; i++)
          curSeqText += config.hasCharacter(0) ? fmt::format("{}", config.getLetter(i)) : fmt::format("{} ", std::string(config.getAsFeature("FORM", i)));
        fmt::print(stdout, "{}\n{}\n\n", config.hasCharacter(0) ? curSeqText : util::strip(curSeqText), util::join(" ", transitionsIndexes));
        curOutputSeqSize = 0;
        curInputSeqSize = 0;
        transitionsIndexes.clear();
      }

    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = config.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    config.moveWordIndexRelaxed(movement.second);
  }

  // Flush the last, possibly incomplete, chunk.
  if (curSeqStartIndex != curInputIndex)
  {
    std::string curSeqText = "";
    for (int i = curSeqStartIndex; i <= curInputIndex; i++)
      curSeqText += config.hasCharacter(0) ? fmt::format("{}", config.getLetter(i)) : fmt::format("{} ", std::string(config.getAsFeature("FORM", i)));
    fmt::print(stdout, "{}\n{}\n\n", config.hasCharacter(0) ? curSeqText : util::strip(curSeqText), util::join(" ", transitionsIndexes));
    curOutputSeqSize = 0;
    curInputSeqSize = 0;
    curSeqStartIndex = curInputIndex;
  }

  fmt::print(stderr, "Longest output sequence : {}\n", maxOutputSeqSize);
  fmt::print(stderr, "Longest input sequence : {}\n", maxInputSeqSize);
}
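// A minimal usage sketch (hypothetical driver code; the actual call order is
// defined by the caller of this class, not in this file):
//
//   Trainer trainer(machine, batchSize);
//   trainer.createDataset(goldConfigs, /*debug=*/false, dir, epoch,
//                         /*dynamicOracle=*/false, /*explorationThreshold=*/0.1f,
//                         /*memcheck=*/false);
//   trainer.makeDataLoader(dir);
//   float trainLoss = trainer.epoch(/*printAdvancement=*/true);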