From b495167ca9c4db71faed2f18d5d9c41903dd2522 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sat, 6 Mar 2021 21:46:02 +0100 Subject: [PATCH] Parallel extractExamples --- common/include/Dict.hpp | 3 + common/src/Dict.cpp | 21 +++- trainer/src/Trainer.cpp | 226 +++++++++++++++++++++------------------- 3 files changed, 138 insertions(+), 112 deletions(-) diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp index 5da9154..7ff6e01 100644 --- a/common/include/Dict.hpp +++ b/common/include/Dict.hpp @@ -5,6 +5,7 @@ #include <unordered_map> #include <vector> #include <filesystem> +#include <mutex> class Dict { @@ -30,6 +31,7 @@ class Dict std::unordered_map<std::string, int> elementsToIndexes; std::unordered_map<int, std::string> indexesToElements; std::vector<int> nbOccs; + std::mutex elementsMutex; State state; bool isCountingOccs{false}; @@ -43,6 +45,7 @@ class Dict void readFromFile(const char * filename); void insert(const std::string & element); void reset(); + int _getIndexOrInsert(const std::string & element, const std::string & prefix); public : diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp index 882c989..b1de43a 100644 --- a/common/src/Dict.cpp +++ b/common/src/Dict.cpp @@ -90,20 +90,33 @@ void Dict::insert(const std::string & element) } int Dict::getIndexOrInsert(const std::string & element, const std::string & prefix) +{ + if (state == State::Open) + elementsMutex.lock(); + + int index = _getIndexOrInsert(element, prefix); + + if (state == State::Open) + elementsMutex.unlock(); + + return index; +} + +int Dict::_getIndexOrInsert(const std::string & element, const std::string & prefix) { if (element.empty()) - return getIndexOrInsert(emptyValueStr, prefix); + return _getIndexOrInsert(emptyValueStr, prefix); if (util::printedLength(element) == 1 and util::isSeparator(util::utf8char(element))) { - return getIndexOrInsert(separatorValueStr, prefix); + return _getIndexOrInsert(separatorValueStr, prefix); } if (util::isNumber(element)) - return getIndexOrInsert(numberValueStr, prefix); + return _getIndexOrInsert(numberValueStr, prefix); if (util::isUrl(element)) - return getIndexOrInsert(urlValueStr, prefix); + return _getIndexOrInsert(urlValueStr, prefix); auto prefixed = prefix.empty() ? element : fmt::format("{}({})", prefix, element); const auto & found = elementsToIndexes.find(prefixed); diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 298e2a9..6c490bd 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -1,5 +1,6 @@ #include "Trainer.hpp" #include "SubConfig.hpp" +#include <execution> Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize) { @@ -35,7 +36,8 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std: torch::AutoGradMode useGrad(false); int maxNbExamplesPerFile = 50000; - std::map<std::string, Examples> examplesPerState; + std::unordered_map<std::string, Examples> examplesPerState; + std::mutex examplesMutex; std::filesystem::create_directories(dir); @@ -46,144 +48,152 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std: fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : ""); - int totalNbExamples = 0; + std::atomic<int> totalNbExamples = 0; - for (auto & config : configs) - { - config.addPredicted(machine.getPredicted()); - config.setStrategy(machine.getStrategyDefinition()); - config.setState(config.getStrategy().getInitialState()); - - while (true) + NeuralNetworkImpl::device = torch::kCPU; + machine.to(NeuralNetworkImpl::device); + std::for_each(std::execution::par_unseq, configs.begin(), configs.end(), + [this, maxNbExamplesPerFile, &examplesPerState, &totalNbExamples, debug, dynamicOracle, explorationThreshold, dir, epoch, &examplesMutex](SubConfig & config) { - if (debug) - config.printForDebug(stderr); + config.addPredicted(machine.getPredicted()); + config.setStrategy(machine.getStrategyDefinition()); + config.setState(config.getStrategy().getInitialState()); - if (machine.hasSplitWordTransitionSet()) - config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); + while (true) + { + if (debug) + config.printForDebug(stderr); - auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config); - config.setAppliableTransitions(appliableTransitions); + if (machine.hasSplitWordTransitionSet()) + config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); - torch::Tensor context; + auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config); + config.setAppliableTransitions(appliableTransitions); - try - { - context = machine.getClassifier(config.getState())->getNN()->extractContext(config); - } catch(std::exception & e) - { - util::myThrow(fmt::format("Failed to extract context : {}", e.what())); - } + torch::Tensor context; - Transition * transition = nullptr; + try + { + context = machine.getClassifier(config.getState())->getNN()->extractContext(config); + } catch(std::exception & e) + { + util::myThrow(fmt::format("Failed to extract context : {}", e.what())); + } - auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle); + Transition * transition = nullptr; - Transition * goldTransition = goldTransitions[0]; - if (config.getState() == "parser") - goldTransitions[std::rand()%goldTransitions.size()]; + auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle); - int nbClasses = machine.getTransitionSet(config.getState()).size(); + Transition * goldTransition = goldTransitions[0]; + if (config.getState() == "parser") + goldTransitions[std::rand()%goldTransitions.size()]; - float bestScore = -std::numeric_limits<float>::max(); + int nbClasses = machine.getTransitionSet(config.getState()).size(); - float entropy = 0.0; - - if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") - { - auto & classifier = *machine.getClassifier(config.getState()); - auto prediction = classifier.isRegression() ? classifier.getNN()->forward(context, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(context, config.getState()).squeeze(0), 0); - entropy = NeuralNetworkImpl::entropy(prediction); - - std::vector<int> candidates; + float bestScore = -std::numeric_limits<float>::max(); - for (unsigned int i = 0; i < prediction.size(0); i++) + float entropy = 0.0; + + if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") { - float score = prediction[i].item<float>(); - if (score > bestScore and appliableTransitions[i]) - bestScore = score; + auto & classifier = *machine.getClassifier(config.getState()); + auto prediction = classifier.isRegression() ? classifier.getNN()->forward(context, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(context, config.getState()).squeeze(0), 0); + entropy = NeuralNetworkImpl::entropy(prediction); + + std::vector<int> candidates; + + for (unsigned int i = 0; i < prediction.size(0); i++) + { + float score = prediction[i].item<float>(); + if (score > bestScore and appliableTransitions[i]) + bestScore = score; + } + + for (unsigned int i = 0; i < prediction.size(0); i++) + { + float score = prediction[i].item<float>(); + if (appliableTransitions[i] and bestScore - score <= explorationThreshold) + candidates.emplace_back(i); + } + + transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]); } - - for (unsigned int i = 0; i < prediction.size(0); i++) + else { - float score = prediction[i].item<float>(); - if (appliableTransitions[i] and bestScore - score <= explorationThreshold) - candidates.emplace_back(i); + transition = goldTransition; } - transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]); - } - else - { - transition = goldTransition; - } - - if (!transition or !goldTransition) - { - config.printForDebug(stderr); - util::myThrow("No transition appliable !"); - } + if (!transition or !goldTransition) + { + config.printForDebug(stderr); + util::myThrow("No transition appliable !"); + } - std::vector<long> goldIndexes; - bool exampleIsBanned = machine.getClassifier(config.getState())->exampleIsBanned(config); + std::vector<long> goldIndexes; + bool exampleIsBanned = machine.getClassifier(config.getState())->exampleIsBanned(config); - if (machine.getClassifier(config.getState())->isRegression()) - { - entropy = 0.0; - auto errMessage = fmt::format("Invalid regression transition '{}'", transition->getName()); - auto splited = util::split(transition->getName(), ' '); - if (splited.size() != 3 or splited[0] != "WRITESCORE") - util::myThrow(errMessage); - auto col = splited[2]; - splited = util::split(splited[1], '.'); - if (splited.size() != 2) - util::myThrow(errMessage); - auto object = Config::str2object(splited[0]); - int index = std::stoi(splited[1]); - - float regressionTarget = std::stof(config.getConst(col, config.getRelativeWordIndex(object, index), 0)); - goldIndexes.emplace_back(util::float2long(regressionTarget)); - } - else - { - for (auto & t : goldTransitions) - goldIndexes.emplace_back(machine.getTransitionSet(config.getState()).getTransitionIndex(t)); - - } + if (machine.getClassifier(config.getState())->isRegression()) + { + entropy = 0.0; + auto errMessage = fmt::format("Invalid regression transition '{}'", transition->getName()); + auto splited = util::split(transition->getName(), ' '); + if (splited.size() != 3 or splited[0] != "WRITESCORE") + util::myThrow(errMessage); + auto col = splited[2]; + splited = util::split(splited[1], '.'); + if (splited.size() != 2) + util::myThrow(errMessage); + auto object = Config::str2object(splited[0]); + int index = std::stoi(splited[1]); + + float regressionTarget = std::stof(config.getConst(col, config.getRelativeWordIndex(object, index), 0)); + goldIndexes.emplace_back(util::float2long(regressionTarget)); + } + else + { + for (auto & t : goldTransitions) + goldIndexes.emplace_back(machine.getTransitionSet(config.getState()).getTransitionIndex(t)); - if (!exampleIsBanned) - { - totalNbExamples += 1; - if (totalNbExamples >= (int)safetyNbExamplesMax) - util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax))); + } - examplesPerState[config.getState()].addContext(context); - examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes); - examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle); - } + if (!exampleIsBanned) + { + totalNbExamples += 1; + if (totalNbExamples >= (int)safetyNbExamplesMax) + util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax))); + + examplesMutex.lock(); + examplesPerState[config.getState()].addContext(context); + examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes); + examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle); + examplesMutex.unlock(); + } - config.setChosenActionScore(bestScore); + config.setChosenActionScore(bestScore); - transition->apply(config, entropy); - config.addToHistory(transition->getName()); + transition->apply(config, entropy); + config.addToHistory(transition->getName()); - auto movement = config.getStrategy().getMovement(config, transition->getName()); - if (debug) - fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second); - if (movement == Strategy::endMovement) - break; + auto movement = config.getStrategy().getMovement(config, transition->getName()); + if (debug) + fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second); + if (movement == Strategy::endMovement) + break; - config.setState(movement.first); - config.moveWordIndexRelaxed(movement.second); + config.setState(movement.first); + config.moveWordIndexRelaxed(movement.second); - if (config.needsUpdate()) - config.update(); - } // End while true - } // End for on configs + if (config.needsUpdate()) + config.update(); + } // End while true + }); // End for on configs for (auto & it : examplesPerState) it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle); + NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice(); + machine.to(NeuralNetworkImpl::device); + std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w"); if (!f) util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str())); -- GitLab