From ecf7290b97c4de62a4e8f647b67c66b4c98e98f0 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 13 Apr 2020 23:56:34 +0200 Subject: [PATCH] Added dynamical oracle and extracted examples are savec to the disk not to use too much memory --- decoder/src/Decoder.cpp | 3 +- torch_modules/include/ConfigDataset.hpp | 17 ++- torch_modules/src/ConfigDataset.cpp | 70 ++++++++--- trainer/include/Trainer.hpp | 15 ++- trainer/src/MacaonTrain.cpp | 30 ++--- trainer/src/Trainer.cpp | 147 ++++++++++++++++-------- 6 files changed, 196 insertions(+), 86 deletions(-) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 5c81d84..1c16521 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -9,6 +9,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool { torch::AutoGradMode useGrad(false); machine.trainMode(false); + machine.getStrategy().reset(); config.addPredicted(machine.getPredicted()); constexpr int printInterval = 50; @@ -27,9 +28,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool if (machine.hasSplitWordTransitionSet()) config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); - auto dictState = machine.getDict(config.getState()).getState(); auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back(); - machine.getDict(config.getState()).setState(dictState); auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); diff --git a/torch_modules/include/ConfigDataset.hpp b/torch_modules/include/ConfigDataset.hpp index 1c60de7..59289f4 100644 --- a/torch_modules/include/ConfigDataset.hpp +++ b/torch_modules/include/ConfigDataset.hpp @@ -4,19 +4,24 @@ #include <torch/torch.h> #include "Config.hpp" -class ConfigDataset : public torch::data::Dataset<ConfigDataset> +class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDataset, std::pair<torch::Tensor,torch::Tensor>> { private : - torch::Tensor data; std::size_t size_{0}; - std::size_t contextSize{0}; + std::vector<std::tuple<int,int,std::filesystem::path>> exampleLocations; + torch::Tensor loadedTensor; + std::optional<std::size_t> loadedTensorIndex; + std::size_t nextIndexToGive{0}; public : - explicit ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes); - torch::optional<size_t> size() const override; - torch::data::Example<> get(size_t index) override; + explicit ConfigDataset(std::filesystem::path dir); + c10::optional<std::size_t> size() const override; + c10::optional<std::pair<torch::Tensor,torch::Tensor>> get_batch(std::size_t batchSize) override; + void reset() override; + void load(torch::serialize::InputArchive &) override; + void save(torch::serialize::OutputArchive &) const override; }; #endif diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp index 35f1942..e73e88f 100644 --- a/torch_modules/src/ConfigDataset.cpp +++ b/torch_modules/src/ConfigDataset.cpp @@ -1,30 +1,72 @@ #include "ConfigDataset.hpp" #include "NeuralNetwork.hpp" -ConfigDataset::ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes) +ConfigDataset::ConfigDataset(std::filesystem::path dir) { - if (contexts.size() != classes.size()) - util::myThrow(fmt::format("contexts.size()={} classes.size()={}", contexts.size(), classes.size())); + for (auto & entry : std::filesystem::directory_iterator(dir)) + if (entry.is_regular_file()) + { + auto splited = util::split(entry.path().stem().string(), '-'); + if (splited.size() != 2) + continue; + exampleLocations.emplace_back(std::make_tuple(std::stoi(splited[0]), std::stoi(splited[1]), entry.path())); + size_ += 1 + std::get<1>(exampleLocations.back()) - std::get<0>(exampleLocations.back()); + } +} + +c10::optional<std::size_t> ConfigDataset::size() const +{ + return size_; +} - size_ = contexts.size(); - contextSize = contexts.back().size(0); - std::vector<torch::Tensor> total; - for (unsigned int i = 0; i < contexts.size(); i++) +c10::optional<std::pair<torch::Tensor,torch::Tensor>> ConfigDataset::get_batch(std::size_t batchSize) +{ + if (!loadedTensorIndex.has_value()) + { + loadedTensorIndex = 0; + nextIndexToGive = 0; + torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device); + } + if ((int)nextIndexToGive >= loadedTensor.size(0)) { - total.emplace_back(contexts[i]); - total.emplace_back(classes[i]); + nextIndexToGive = 0; + loadedTensorIndex = loadedTensorIndex.value() + 1; + + if (loadedTensorIndex >= exampleLocations.size()) + return c10::optional<std::pair<torch::Tensor,torch::Tensor>>(); + + torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device); } - data = torch::cat(total); + std::pair<torch::Tensor, torch::Tensor> batch; + if ((int)nextIndexToGive + (int)batchSize < loadedTensor.size(0)) + { + batch.first = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, 0, loadedTensor.size(1)-1); + batch.second = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, loadedTensor.size(1)-1, 1); + nextIndexToGive += batchSize; + } + else + { + batch.first = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, 0, loadedTensor.size(1)-1); + batch.second = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, loadedTensor.size(1)-1, 1); + nextIndexToGive = loadedTensor.size(0); + } + + return batch; } -torch::optional<size_t> ConfigDataset::size() const +void ConfigDataset::reset() +{ + std::random_shuffle(exampleLocations.begin(), exampleLocations.end()); + loadedTensorIndex = std::optional<std::size_t>(); + nextIndexToGive = 0; +} + +void ConfigDataset::load(torch::serialize::InputArchive &) { - return size_; } -torch::data::Example<> ConfigDataset::get(size_t index) +void ConfigDataset::save(torch::serialize::OutputArchive &) const { - return {data.narrow(0, index*(contextSize+1), contextSize).to(NeuralNetworkImpl::device), data.narrow(0, index*(contextSize+1)+contextSize, 1).to(NeuralNetworkImpl::device)}; } diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index b5a548c..03e7616 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -10,28 +10,31 @@ class Trainer private : using Dataset = ConfigDataset; - using DataLoader = std::unique_ptr<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler>, std::default_delete<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler> > >; + using DataLoader = std::unique_ptr<torch::data::StatefulDataLoader<Dataset>>; private : ReadingMachine & machine; + std::unique_ptr<Dataset> trainDataset{nullptr}; + std::unique_ptr<Dataset> devDataset{nullptr}; DataLoader dataLoader{nullptr}; DataLoader devDataLoader{nullptr}; std::unique_ptr<torch::optim::Adam> optimizer; std::size_t epochNumber{0}; - int batchSize{64}; + int batchSize; int nbExamples{0}; private : - void extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes); + void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval); float processDataset(DataLoader & loader, bool train, bool printAdvancement); + void saveExamples(std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes, int & lastSavedIndex, int & currentExampleIndex, std::filesystem::path dir); public : - Trainer(ReadingMachine & machine); - void createDataset(SubConfig & goldConfig, bool debug); - void createDevDataset(SubConfig & goldConfig, bool debug); + Trainer(ReadingMachine & machine, int batchSize); + void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval); + void createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval); float epoch(bool printAdvancement); float evalOnDev(bool printAdvancement); void loadOptimizer(std::filesystem::path path); diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 25311df..280643a 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -31,6 +31,10 @@ po::options_description MacaonTrain::getOptionsDescription() "Raw text file of the development corpus") ("nbEpochs,n", po::value<int>()->default_value(5), "Number of training epochs") + ("batchSize", po::value<int>()->default_value(64), + "Number of examples per batch") + ("dynamicOracleInterval", po::value<int>()->default_value(-1), + "Number of examples per batch") ("machine", po::value<std::string>()->default_value(""), "Reading machine file content") ("help,h", "Produce this help message"); @@ -90,6 +94,8 @@ int MacaonTrain::main() auto devTsvFile = variables["devTSV"].as<std::string>(); auto devRawFile = variables["devTXT"].as<std::string>(); auto nbEpoch = variables["nbEpochs"].as<int>(); + auto batchSize = variables["batchSize"].as<int>(); + auto dynamicOracleInterval = variables["dynamicOracleInterval"].as<int>(); bool debug = variables.count("debug") == 0 ? false : true; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; bool computeDevScore = variables.count("devScore") == 0 ? false : true; @@ -115,19 +121,11 @@ int MacaonTrain::main() BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile); BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); - SubConfig config(goldConfig, goldConfig.getNbLines()); fillDicts(machine, goldConfig); - Trainer trainer(machine); - trainer.createDataset(config, debug); - if (!computeDevScore) - { - machine.getStrategy().reset(); - SubConfig devConfig(devGoldConfig, devGoldConfig.getNbLines()); - trainer.createDevDataset(devConfig, debug); - } + Trainer trainer(machine, batchSize); Decoder decoder(machine); float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max(); @@ -154,14 +152,21 @@ int MacaonTrain::main() std::fclose(f); } + trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, dynamicOracleInterval); + if (!computeDevScore) + trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval); + auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt"; if (std::filesystem::exists(trainInfos)) trainer.loadOptimizer(optimizerCheckpoint); for (; currentEpoch < nbEpoch; currentEpoch++) { + trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, dynamicOracleInterval); + if (!computeDevScore) + trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval); + float loss = trainer.epoch(printAdvancement); - machine.getStrategy().reset(); if (debug) fmt::print(stderr, "Decoding dev :\n"); std::vector<std::pair<float,std::string>> devScores; @@ -169,7 +174,6 @@ int MacaonTrain::main() { auto devConfig = devGoldConfig; decoder.decode(devConfig, 1, debug, printAdvancement); - machine.getStrategy().reset(); decoder.evaluate(devConfig, modelPath, devTsvFile); devScores = decoder.getF1Scores(machine.getPredicted()); } @@ -192,9 +196,9 @@ int MacaonTrain::main() if (!devScoresStr.empty()) devScoresStr.pop_back(); devScoreMean /= devScores.size(); - bool saved = devScoreMean > bestDevScore; + bool saved = devScoreMean >= bestDevScore; if (!computeDevScore) - saved = devScoreMean < bestDevScore; + saved = devScoreMean <= bestDevScore; if (saved) { bestDevScore = devScoreMean; diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index d69d74a..5d3b2ec 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -1,42 +1,77 @@ #include "Trainer.hpp" #include "SubConfig.hpp" -Trainer::Trainer(ReadingMachine & machine) : machine(machine) +Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize) { } -void Trainer::createDataset(SubConfig & config, bool debug) +void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval) { - machine.trainMode(true); - std::vector<torch::Tensor> contexts; - std::vector<torch::Tensor> classes; + SubConfig config(goldConfig, goldConfig.getNbLines()); - extractExamples(config, debug, contexts, classes); + extractExamples(config, debug, dir, epoch, dynamicOracleInterval); + trainDataset.reset(new Dataset(dir)); - nbExamples = classes.size(); + nbExamples = trainDataset->size().value(); - dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); + dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); - optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999))); + if (optimizer.get() == nullptr) + optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999))); } -void Trainer::createDevDataset(SubConfig & config, bool debug) +void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval) { - machine.trainMode(false); - std::vector<torch::Tensor> contexts; - std::vector<torch::Tensor> classes; + SubConfig config(goldConfig, goldConfig.getNbLines()); - extractExamples(config, debug, contexts, classes); + extractExamples(config, debug, dir, epoch, dynamicOracleInterval); + devDataset.reset(new Dataset(dir)); - devDataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); + devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); } -void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes) +void Trainer::saveExamples(std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes, int & lastSavedIndex, int & currentExampleIndex, std::filesystem::path dir) { - fmt::print(stderr, "[{}] Starting to extract examples\n", util::getTime()); + auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1); + auto filename = fmt::format("{}-{}.tensor", lastSavedIndex, currentExampleIndex-1); + torch::save(tensorToSave, dir/filename); + lastSavedIndex = currentExampleIndex; + contexts.clear(); + classes.clear(); +} + +void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval) +{ + torch::AutoGradMode useGrad(false); + machine.trainMode(false); + + int maxNbExamplesPerFile = 250000; + int currentExampleIndex = 0; + int lastSavedIndex = 0; + std::vector<torch::Tensor> contexts; + std::vector<torch::Tensor> classes; + + std::filesystem::create_directories(dir); config.addPredicted(machine.getPredicted()); config.setState(machine.getStrategy().getInitialState()); + machine.getStrategy().reset(); + + auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch); + bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile); + if (epoch != 0 and (dynamicOracleInterval == -1 or epoch % dynamicOracleInterval)) + mustExtract = false; + + if (!mustExtract) + return; + + bool dynamicOracle = epoch != 0; + + fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : ""); + + for (auto & entry : std::filesystem::directory_iterator(dir)) + if (entry.is_regular_file()) + std::filesystem::remove(entry.path()); while (true) { @@ -46,31 +81,6 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch: if (machine.hasSplitWordTransitionSet()) config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); - auto * transition = machine.getTransitionSet().getBestAppliableTransition(config); - if (!transition) - { - config.printForDebug(stderr); - util::myThrow("No transition appliable !"); - } - - if (config.isMultiword(config.getWordIndex())) - if (transition->getName() == "ADDCHARTOWORD") - { - config.printForDebug(stderr); - - auto & splitTrans = config.getAppliableSplitTransitions(); - fmt::print(stderr, "splitTrans.size() = {}\n", splitTrans.size()); - for (auto & trans : splitTrans) - fmt::print(stderr, "cost {} : '{}'\n", trans->getCost(config), trans->getName()); - util::myThrow(fmt::format("Transition should have been a split")); - } - if (transition->getName() == "ENDWORD") - if (config.getAsFeature("FORM",config.getWordIndex()) != config.getConst("FORM",config.getWordIndex(),0)) - { - config.printForDebug(stderr); - util::myThrow(fmt::format("Words don't match")); - } - std::vector<std::vector<long>> context; try @@ -83,12 +93,51 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch: util::myThrow(fmt::format("Failed to extract context : {}", e.what())); } + Transition * transition = nullptr; + + if (dynamicOracle and config.getState() != "tokenizer") + { + auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); + auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); + + int chosenTransition = -1; + float bestScore = std::numeric_limits<float>::min(); + + for (unsigned int i = 0; i < prediction.size(0); i++) + { + float score = prediction[i].item<float>(); + if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config)) + { + chosenTransition = i; + bestScore = score; + } + } + + transition = machine.getTransitionSet().getTransition(chosenTransition); + } + else + { + transition = machine.getTransitionSet().getBestAppliableTransition(config); + } + + if (!transition) + { + config.printForDebug(stderr); + util::myThrow("No transition appliable !"); + } + int goldIndex = machine.getTransitionSet().getTransitionIndex(transition); auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong)); gold[0] = goldIndex; for (auto & element : context) + { + currentExampleIndex++; classes.emplace_back(gold); + } + + if (currentExampleIndex-lastSavedIndex >= maxNbExamplesPerFile) + saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir); transition->apply(config); config.addToHistory(transition->getName()); @@ -106,7 +155,15 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch: config.update(); } - fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(classes.size())); + if (!contexts.empty()) + saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir); + + std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w"); + if (!f) + util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str())); + std::fclose(f); + + fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(currentExampleIndex)); } float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement) @@ -129,8 +186,8 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance if (train) optimizer->zero_grad(); - auto data = batch.data; - auto labels = batch.target.squeeze(); + auto data = batch.first; + auto labels = batch.second; auto prediction = machine.getClassifier()->getNN()(data); if (prediction.dim() == 1) -- GitLab