From 9e3b06af5ee5ba3280c7408ef4d66f6252429ec6 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 25 May 2020 13:34:20 +0200 Subject: [PATCH] Introduced trainStrategy --- decoder/include/Beam.hpp | 5 +- decoder/src/Beam.cpp | 20 +++-- reading_machine/include/ReadingMachine.hpp | 3 + reading_machine/src/ReadingMachine.cpp | 9 ++- torch_modules/src/ConfigDataset.cpp | 2 +- trainer/include/MacaonTrain.hpp | 4 + trainer/include/Trainer.hpp | 23 +++++- trainer/src/MacaonTrain.cpp | 86 +++++++++++++++++----- trainer/src/Trainer.cpp | 67 +++++++++-------- 9 files changed, 153 insertions(+), 66 deletions(-) diff --git a/decoder/include/Beam.hpp b/decoder/include/Beam.hpp index 4153460..2c34d3f 100644 --- a/decoder/include/Beam.hpp +++ b/decoder/include/Beam.hpp @@ -16,14 +16,15 @@ class Beam BaseConfig config; int nextTransition{-1}; - boost::circular_buffer<double> probabilities{500}; boost::circular_buffer<std::string> name{20}; float meanProbability{0.0}; + int nbTransitions = 0; + double totalProbability{0.0}; bool ended{false}; public : - Element(const BaseConfig & model, int nextTransition, const boost::circular_buffer<double> & probabilities, const boost::circular_buffer<std::string> & name); + Element(const Element & other, int nextTransition); Element(const BaseConfig & model); }; diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp index 3af7ab0..47afb72 100644 --- a/decoder/src/Beam.cpp +++ b/decoder/src/Beam.cpp @@ -8,8 +8,9 @@ Beam::Beam(std::size_t width, float threshold, BaseConfig & model, const Reading elements.emplace_back(model); } -Beam::Element::Element(const BaseConfig & model, int nextTransition, const boost::circular_buffer<double> & probabilities, const boost::circular_buffer<std::string> & name) : config(model), nextTransition(nextTransition), probabilities(probabilities), name(name) +Beam::Element::Element(const Element & other, int nextTransition) : Element(other) { + this->nextTransition = nextTransition; } Beam::Element::Element(const BaseConfig & model) : config(model) @@ -71,22 +72,19 @@ void Beam::update(ReadingMachine & machine, bool debug) if (width > 1) for (unsigned int i = 1; i < scoresOfTransitions.size(); i++) { - elements.emplace_back(elements[index].config, scoresOfTransitions[i].second, elements[index].probabilities, elements[index].name); + elements.emplace_back(elements[index], scoresOfTransitions[i].second); elements.back().name.push_back(std::to_string(i)); - elements.back().probabilities.push_back(scoresOfTransitions[i].first); - elements.back().meanProbability = 0.0; - for (auto & p : elements.back().probabilities) - elements.back().meanProbability += p; - elements.back().meanProbability /= elements.back().probabilities.size(); + elements.back().totalProbability += scoresOfTransitions[i].first; + elements.back().nbTransitions++; + elements.back().meanProbability = elements.back().totalProbability / elements.back().nbTransitions; } elements[index].nextTransition = scoresOfTransitions[0].second; - elements[index].probabilities.push_back(scoresOfTransitions[0].first); + elements[index].totalProbability += scoresOfTransitions[0].first; + elements[index].nbTransitions++; elements[index].name.push_back("0"); elements[index].meanProbability = 0.0; - for (auto & p : elements[index].probabilities) - elements[index].meanProbability += p; - elements[index].meanProbability /= elements[index].probabilities.size(); + elements[index].meanProbability = elements[index].totalProbability / elements[index].nbTransitions; if (debug) { diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index d4b419a..13f5cbc 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -19,6 +19,8 @@ class ReadingMachine std::filesystem::path path; std::unique_ptr<Classifier> classifier; std::vector<std::string> strategyDefinition; + std::vector<std::string> classifierDefinition; + std::string classifierName; std::set<std::string> predicted; std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr}; @@ -48,6 +50,7 @@ class ReadingMachine void loadLastSaved(); void setCountOcc(bool countOcc); void removeRareDictElements(float rarityThreshold); + void resetClassifier(); }; #endif diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index e48d507..973d680 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -58,7 +58,8 @@ void ReadingMachine::readFromFile(std::filesystem::path path) while (util::doIfNameMatch(std::regex("Classifier : (.+)"), lines[curLine++], [this,path,&lines,&curLine](auto sm) { - std::vector<std::string> classifierDefinition; + classifierDefinition.clear(); + classifierName = sm.str(1); if (lines[curLine] != "{") util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine])); @@ -196,3 +197,9 @@ void ReadingMachine::removeRareDictElements(float rarityThreshold) classifier->getNN()->removeRareDictElements(rarityThreshold); } +void ReadingMachine::resetClassifier() +{ + classifier.reset(new Classifier(classifierName, path, classifierDefinition)); + loadDicts(); +} + diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp index 30dce0e..2022e52 100644 --- a/torch_modules/src/ConfigDataset.cpp +++ b/torch_modules/src/ConfigDataset.cpp @@ -6,7 +6,7 @@ ConfigDataset::ConfigDataset(std::filesystem::path dir) for (auto & entry : std::filesystem::directory_iterator(dir)) if (entry.is_regular_file()) { - auto stem = entry.path().stem().string(); + auto stem = util::split(entry.path().stem().string(), '.')[0]; if (stem == "extracted") continue; auto state = util::split(stem, '_')[0]; diff --git a/trainer/include/MacaonTrain.hpp b/trainer/include/MacaonTrain.hpp index 9a92664..731dea3 100644 --- a/trainer/include/MacaonTrain.hpp +++ b/trainer/include/MacaonTrain.hpp @@ -20,6 +20,10 @@ class MacaonTrain po::options_description getOptionsDescription(); po::variables_map checkOptions(po::options_description & od); + private : + + Trainer::TrainStrategy parseTrainStrategy(std::string s); + public : MacaonTrain(int argc, char ** argv); diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 29485d2..fcbae07 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -7,6 +7,20 @@ class Trainer { + public : + + enum TrainAction + { + ExtractGold, + ExtractDynamic, + DeleteExamples, + ResetOptimizer, + ResetParameters, + Save + }; + using TrainStrategy = std::map<std::size_t, std::set<TrainAction>>; + static TrainAction str2TrainAction(const std::string & s); + private : static constexpr std::size_t safetyNbExamplesMax = 10*1000*1000; @@ -19,7 +33,7 @@ class Trainer int currentExampleIndex{0}; int lastSavedIndex{0}; - void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold); + void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int currentEpoch, bool dynamicOracle); void addContext(std::vector<std::vector<long>> & context); void addClass(int goldIndex); }; @@ -41,15 +55,16 @@ class Trainer private : - void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval); + void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle); float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples); void fillDicts(SubConfig & config, bool debug); public : Trainer(ReadingMachine & machine, int batchSize); - void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval); - void createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval); + void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle); + void makeDataLoader(std::filesystem::path dir); + void makeDevDataLoader(std::filesystem::path dir); void fillDicts(BaseConfig & goldConfig, bool debug); float epoch(bool printAdvancement); float evalOnDev(bool printAdvancement); diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index f7d0a39..5de8971 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -33,12 +33,12 @@ po::options_description MacaonTrain::getOptionsDescription() "Number of training epochs") ("batchSize", po::value<int>()->default_value(64), "Number of examples per batch") - ("dynamicOracleInterval", po::value<int>()->default_value(-1), - "Every X epochs, the machine will be used to decode the train and dev corpora. Thus allowing the machine to train using it's own predictions as feature. A value of -1 means the machine will always train on GOLD features. This option slows training down by a LOT.") ("rarityThreshold", po::value<float>()->default_value(70.0), "During train, the X% rarest elements will be treated as unknown values") ("machine", po::value<std::string>()->default_value(""), "Reading machine file content") + ("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold"), + "Description of what should happen during training") ("pretrainedEmbeddings", po::value<std::string>()->default_value(""), "File containing pretrained embeddings, w2v format") ("help,h", "Produce this help message"); @@ -69,6 +69,27 @@ po::variables_map MacaonTrain::checkOptions(po::options_description & od) return vm; } +Trainer::TrainStrategy MacaonTrain::parseTrainStrategy(std::string s) +{ + Trainer::TrainStrategy ts; + + try + { + auto splited = util::split(s, ':'); + for (auto & ss : splited) + { + auto elements = util::split(ss, ','); + + int epoch = std::stoi(elements[0]); + + for (unsigned int i = 1; i < elements.size(); i++) + ts[epoch].insert(Trainer::str2TrainAction(elements[i])); + } + } catch (std::exception & e) {util::myThrow(fmt::format("caught '{}' parsing '{}'", e.what(), s));} + + return ts; +} + int MacaonTrain::main() { auto od = getOptionsDescription(); @@ -83,13 +104,15 @@ int MacaonTrain::main() auto devRawFile = variables["devTXT"].as<std::string>(); auto nbEpoch = variables["nbEpochs"].as<int>(); auto batchSize = variables["batchSize"].as<int>(); - auto dynamicOracleInterval = variables["dynamicOracleInterval"].as<int>(); auto rarityThreshold = variables["rarityThreshold"].as<float>(); bool debug = variables.count("debug") == 0 ? false : true; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; bool computeDevScore = variables.count("devScore") == 0 ? false : true; auto machineContent = variables["machine"].as<std::string>(); auto pretrainedEmbeddings = variables["pretrainedEmbeddings"].as<std::string>(); + auto trainStrategyStr = variables["trainStrategy"].as<std::string>(); + + auto trainStrategy = parseTrainStrategy(trainStrategyStr); torch::globalContext().setBenchmarkCuDNN(true); @@ -146,20 +169,15 @@ int MacaonTrain::main() { if (buffer != std::fgets(buffer, 1024, f)) break; + bool saved = util::split(util::split(buffer, '\t')[0], ' ').back() == "SAVED"; float devScoreMean = std::stof(util::split(buffer, '\t').back()); - if (computeDevScore and (devScoreMean > bestDevScore or currentEpoch == dynamicOracleInterval)) - bestDevScore = devScoreMean; - if (!computeDevScore and (devScoreMean < bestDevScore or currentEpoch == dynamicOracleInterval)) + if (saved) bestDevScore = devScoreMean; currentEpoch++; } std::fclose(f); } - trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, dynamicOracleInterval); - if (!computeDevScore) - trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval); - machine.getClassifier()->resetOptimizer(); auto optimizerCheckpoint = machinePath.parent_path() / "checkpoint.optimizer"; if (std::filesystem::exists(trainInfos)) @@ -167,9 +185,44 @@ int MacaonTrain::main() for (; currentEpoch < nbEpoch; currentEpoch++) { - trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, dynamicOracleInterval); + bool saved = false; + + if (trainStrategy[currentEpoch].count(Trainer::TrainAction::DeleteExamples)) + { + for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/train")) + if (entry.is_regular_file()) + std::filesystem::remove(entry.path()); + + if (!computeDevScore) + for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/dev")) + if (entry.is_regular_file()) + std::filesystem::remove(entry.path()); + } + if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)) + { + trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)); + if (!computeDevScore) + trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)); + } + if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetOptimizer)) + { + if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters)) + { + machine.resetClassifier(); + machine.getClassifier()->getNN()->registerEmbeddings(pretrainedEmbeddings); + machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device); + } + + machine.getClassifier()->resetOptimizer(); + } + if (trainStrategy[currentEpoch].count(Trainer::TrainAction::Save)) + { + saved = true; + } + + trainer.makeDataLoader(modelPath/"examples/train"); if (!computeDevScore) - trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval); + trainer.makeDevDataLoader(modelPath/"examples/dev"); float loss = trainer.epoch(printAdvancement); if (debug) @@ -201,13 +254,12 @@ int MacaonTrain::main() if (!devScoresStr.empty()) devScoresStr.pop_back(); devScoreMean /= devScores.size(); - bool saved = devScoreMean >= bestDevScore; - if (!computeDevScore) - saved = devScoreMean <= bestDevScore; + if (computeDevScore) + saved = saved or devScoreMean >= bestDevScore; + else + saved = saved or devScoreMean <= bestDevScore; - if (currentEpoch == dynamicOracleInterval) - saved = true; if (saved) { bestDevScore = devScoreMean; diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 72b56e6..a12c6c5 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -5,33 +5,29 @@ Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), ba { } -void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval) +void Trainer::makeDataLoader(std::filesystem::path dir) { - SubConfig config(goldConfig, goldConfig.getNbLines()); - - machine.trainMode(false); - machine.setDictsState(Dict::State::Closed); - - extractExamples(config, debug, dir, epoch, dynamicOracleInterval); trainDataset.reset(new Dataset(dir)); - dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); } -void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval) +void Trainer::makeDevDataLoader(std::filesystem::path dir) +{ + devDataset.reset(new Dataset(dir)); + devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); +} + +void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle) { SubConfig config(goldConfig, goldConfig.getNbLines()); machine.trainMode(false); machine.setDictsState(Dict::State::Closed); - extractExamples(config, debug, dir, epoch, dynamicOracleInterval); - devDataset.reset(new Dataset(dir)); - - devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); + extractExamples(config, debug, dir, epoch, dynamicOracle); } -void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval) +void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle) { torch::AutoGradMode useGrad(false); @@ -45,22 +41,13 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p config.setState(config.getStrategy().getInitialState()); machine.getClassifier()->setState(config.getState()); - auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch); - bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile); - if (epoch != 0 and (dynamicOracleInterval == -1 or epoch % dynamicOracleInterval)) - mustExtract = false; + auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle); - if (!mustExtract) + if (std::filesystem::exists(currentEpochAllExtractedFile)) return; - bool dynamicOracle = epoch != 0; - fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : ""); - for (auto & entry : std::filesystem::directory_iterator(dir)) - if (entry.is_regular_file()) - std::filesystem::remove(entry.path()); - int totalNbExamples = 0; while (true) @@ -88,7 +75,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p goldTransition = machine.getTransitionSet().getBestAppliableTransition(config); - if (dynamicOracle and util::choiceWithProbability(0.8) and config.getState() != "tokenizer" and config.getState() != "segmenter") + if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") { auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); @@ -127,7 +114,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p examplesPerState[config.getState()].addContext(context); examplesPerState[config.getState()].addClass(goldIndex); - examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile); + examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle); transition->apply(config); config.addToHistory(transition->getName()); @@ -147,7 +134,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p } for (auto & it : examplesPerState) - it.second.saveIfNeeded(it.first, dir, 0); + it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle); std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w"); if (!f) @@ -240,7 +227,7 @@ float Trainer::evalOnDev(bool printAdvancement) return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value()); } -void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold) +void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int epoch, bool dynamicOracle) { if (currentExampleIndex-lastSavedIndex < (int)threshold) return; @@ -248,7 +235,7 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem: return; auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1); - auto filename = fmt::format("{}_{}-{}.tensor", state, lastSavedIndex, currentExampleIndex-1); + auto filename = fmt::format("{}_{}-{}.{}.{}.tensor", state, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle); torch::save(tensorToSave, dir/filename); lastSavedIndex = currentExampleIndex; contexts.clear(); @@ -340,3 +327,23 @@ void Trainer::fillDicts(SubConfig & config, bool debug) } } +Trainer::TrainAction Trainer::str2TrainAction(const std::string & s) +{ + if (s == "ExtractGold") + return TrainAction::ExtractGold; + if (s == "ExtractDynamic") + return TrainAction::ExtractDynamic; + if (s == "DeleteExamples") + return TrainAction::DeleteExamples; + if (s == "ResetOptimizer") + return TrainAction::ResetOptimizer; + if (s == "ResetParameters") + return TrainAction::ResetParameters; + if (s == "Save") + return TrainAction::Save; + + util::myThrow(fmt::format("unknown TrainAction '{}'", s)); + + return TrainAction::ExtractGold; +} + -- GitLab