From a81d26c906425fbfde4cdd03d55549eb2d1aaff2 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 21 Apr 2020 15:30:10 +0200
Subject: [PATCH] better repartition of batches between states

---
 torch_modules/include/ConfigDataset.hpp |  28 ++++-
 torch_modules/src/ConfigDataset.cpp     | 131 ++++++++++++++++++------
 trainer/include/Trainer.hpp             |   5 +-
 trainer/src/Trainer.cpp                 |  61 ++++++-----
 4 files changed, 162 insertions(+), 63 deletions(-)

diff --git a/torch_modules/include/ConfigDataset.hpp b/torch_modules/include/ConfigDataset.hpp
index ee8d46c..7ea2335 100644
--- a/torch_modules/include/ConfigDataset.hpp
+++ b/torch_modules/include/ConfigDataset.hpp
@@ -8,11 +8,30 @@ class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDatase
 {
   private :
 
+  struct Holder // Per-state bookkeeping: the tensor files of one state and the read cursor over them.
+  {
+    std::string state; // State name; returned as the third element of every batch.
+    std::vector<std::string> files; // Paths of this state's tensor files; shuffled by reset().
+    torch::Tensor loadedTensor; // Tensor read from files[loadedTensorIndex].
+    int loadedTensorIndex{0}; // Index into files of the tensor currently loaded.
+    int nextIndexToGive{0}; // First row of loadedTensor not handed out yet.
+    std::size_t size_{0}; // Sum of the filesize values passed to addFile().
+    std::size_t nbGiven{0}; // Rows handed out since the last reset().
+
+    Holder(std::string state); // Stores the state name only.
+    void addFile(std::string filename, int filesize); // Registers a file and adds filesize to size_.
+    void reset(); // Shuffles files, zeroes the cursors and loads the first file.
+    std::size_t size() const; // Returns size_.
+    std::size_t sizeLeft() const; // Returns size_ - nbGiven.
+    c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> get_batch(std::size_t batchSize); // At most batchSize rows of this state; empty optional once every file is consumed.
+  };
+
+  private :
+
   std::size_t size_{0};
-  std::vector<std::tuple<int,int,std::filesystem::path,std::string>> exampleLocations;
-  torch::Tensor loadedTensor;
-  std::optional<std::size_t> loadedTensorIndex;
-  std::size_t nextIndexToGive{0};
+  std::map<std::string,Holder> holders; // One Holder per state, keyed by state name.
+  std::map<std::string,int> nbToGive; // Per-state counter decremented by get_batch(); refilled by computeNbToGive().
+  std::vector<std::string> order; // State names; reshuffled at the start of every get_batch() call.
 
   public :
 
@@ -22,6 +41,7 @@ class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDatase
   void reset() override;
   void load(torch::serialize::InputArchive &) override;
   void save(torch::serialize::OutputArchive &) const override;
+  void computeNbToGive(); // Recomputes nbToGive from each holder's sizeLeft().
 };
 
 #endif
diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp
index 3a4a2c5..3a4e646 100644
--- a/torch_modules/src/ConfigDataset.cpp
+++ b/torch_modules/src/ConfigDataset.cpp
@@ -11,10 +11,15 @@ ConfigDataset::ConfigDataset(std::filesystem::path dir)
       continue;
     auto state = util::split(stem, '_')[0];
     auto splited = util::split(util::split(stem, '_')[1], '-');
-    exampleLocations.emplace_back(std::make_tuple(std::stoi(splited[0]), std::stoi(splited[1]), entry.path(), state));
-    size_ += 1 + std::get<1>(exampleLocations.back()) - std::get<0>(exampleLocations.back());
+    int fileSize = 1 + std::stoi(splited[1]) - std::stoi(splited[0]); // Inclusive index range parsed from the file stem.
+    size_ += fileSize;
+    if (!holders.count(state)) // First file seen for this state: create its holder.
+    {
+      holders.emplace(state, Holder(state));
+      order.emplace_back(state);
+    }
+    holders.at(state).addFile(entry.path().string(), fileSize);
   }
-
 }
 
 c10::optional<std::size_t> ConfigDataset::size() const
@@ -24,47 +29,45 @@ c10::optional<std::size_t> ConfigDataset::size() const
 
 c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::get_batch(std::size_t batchSize)
 {
-  if (!loadedTensorIndex.has_value())
+  std::random_shuffle(order.begin(), order.end()); // Visit the states in a fresh random order on each call.
+
+  for (auto & state : order)
   {
-    loadedTensorIndex = 0;
-    nextIndexToGive = 0;
-    torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
+    if (nbToGive.at(state) > 0) // This state still has batches allotted.
+    {
+      nbToGive.at(state)--;
+      auto res = holders.at(state).get_batch(batchSize);
+      if (res.has_value())
+        return res;
+      else
+        nbToGive.at(state) = 0; // Holder returned nothing: stop asking it until the next recount.
+    }
   }
 
-  if ((int)nextIndexToGive >= loadedTensor.size(0))
-  {
-    nextIndexToGive = 0;
-    loadedTensorIndex = loadedTensorIndex.value() + 1;
-
-    if (loadedTensorIndex >= exampleLocations.size())
-      return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
-    torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
-  }
+  computeNbToGive(); // No state produced a batch: recompute the allotments and try once more.
 
-  std::tuple<torch::Tensor, torch::Tensor, std::string> batch;
-  if ((int)nextIndexToGive + (int)batchSize < loadedTensor.size(0))
-  {
-    std::get<0>(batch) = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, 0, loadedTensor.size(1)-1);
-    std::get<1>(batch) = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, loadedTensor.size(1)-1, 1);
-    nextIndexToGive += batchSize;
-  }
-  else
+  for (auto & state : order)
   {
-    std::get<0>(batch) = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, 0, loadedTensor.size(1)-1);
-    std::get<1>(batch) = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, loadedTensor.size(1)-1, 1);
-    nextIndexToGive = loadedTensor.size(0);
+    if (nbToGive.at(state) > 0)
+    {
+      nbToGive.at(state)--;
+      auto res = holders.at(state).get_batch(batchSize);
+      if (res.has_value())
+        return res;
+      else
+        nbToGive.at(state) = 0;
+    }
   }
 
-  std::get<2>(batch) = std::get<3>(exampleLocations[loadedTensorIndex.value()]);
-
-  return batch;
+  return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>(); // Still nothing after the recount: return an empty optional.
 }
 
 void ConfigDataset::reset()
 {
-  std::random_shuffle(exampleLocations.begin(), exampleLocations.end());
-  loadedTensorIndex = std::optional<std::size_t>();
-  nextIndexToGive = 0;
+  for (auto & it : holders) // Rewind every holder before computing the initial allotments.
+    it.second.reset();
+
+  computeNbToGive();
 }
 
 void ConfigDataset::load(torch::serialize::InputArchive &)
@@ -75,3 +78,65 @@ void ConfigDataset::save(torch::serialize::OutputArchive &) const
 {
 }
 
+void ConfigDataset::Holder::addFile(std::string filename, int filesize) // Registers one tensor file and its example count.
+{
+  files.emplace_back(filename);
+  size_ += filesize;
+}
+
+void ConfigDataset::Holder::reset() // Reshuffles the files and restarts reading from the first one.
+{
+  std::random_shuffle(files.begin(), files.end());
+  loadedTensorIndex = 0;
+  nextIndexToGive = 0;
+  nbGiven = 0;
+  torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
+}
+
+c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize) // Next rows of this state, or an empty optional when all files are consumed.
+{
+  if (loadedTensorIndex >= (int)files.size()) // Every file was already consumed.
+    return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
+  if (nextIndexToGive >= loadedTensor.size(0)) // Current tensor fully handed out: move to the next file.
+  {
+    loadedTensorIndex++;
+    if (loadedTensorIndex >= (int)files.size())
+      return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
+    nextIndexToGive = 0;
+    torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
+  }
+
+  int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive); // A batch never spans two files, so the last one of a file may be shorter.
+  nbGiven += nbElementsToGive;
+  auto batch = loadedTensor.narrow(0, nextIndexToGive, nbElementsToGive);
+  nextIndexToGive += nbElementsToGive;
+  return std::make_tuple(batch.narrow(1, 0, batch.size(1)-1), batch.narrow(1, batch.size(1)-1, 1), state); // (all columns but the last, last column, state name).
+}
+
+ConfigDataset::Holder::Holder(std::string state) : state(state)
+{
+}
+
+std::size_t ConfigDataset::Holder::size() const
+{
+  return size_;
+}
+
+std::size_t ConfigDataset::Holder::sizeLeft() const
+{
+  return size_-nbGiven; // Unsigned arithmetic: the result wraps if nbGiven ever exceeds size_.
+}
+
+void ConfigDataset::computeNbToGive() // Refills nbToGive from what each holder has left.
+{
+  int smallestSize = std::numeric_limits<int>::max(); // Will hold the smallest non-zero sizeLeft() among the holders.
+  for (auto & it : holders)
+  {
+    int sizeLeft = it.second.sizeLeft();
+    if (sizeLeft > 0 and sizeLeft < smallestSize)
+      smallestSize = sizeLeft;
+  }
+  for (auto & it : holders)
+    nbToGive[it.first] = std::floor(1.0*it.second.sizeLeft()/smallestSize); // Whole number of times the smallest remainder fits in this state's remainder (0 when nothing is left).
+}
+
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 91aedbe..c087469 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -16,6 +16,10 @@ class Trainer
 
     int currentExampleIndex{0};
     int lastSavedIndex{0};
+
+    void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold); // Saves the buffered examples into dir once at least threshold examples were added since the last save.
+    void addContext(std::vector<std::vector<long>> & context); // Appends one tensor per element of context and advances currentExampleIndex.
+    void addClass(int goldIndex); // Appends goldIndex to classes until classes is as long as contexts.
   };
 
   private :
@@ -37,7 +41,6 @@ class Trainer
 
   void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
   float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);
-void saveExamples(std::string state, Examples & examples, std::filesystem::path dir);
 
   public :
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index f76f220..23b131b 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -33,16 +33,6 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys
   devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 }
 
-void Trainer::saveExamples(std::string state, Examples & examples, std::filesystem::path dir)
-{
-  auto tensorToSave = torch::cat({torch::stack(examples.contexts), torch::stack(examples.classes)}, 1);
-  auto filename = fmt::format("{}_{}-{}.tensor", state, examples.lastSavedIndex, examples.currentExampleIndex-1);
-  torch::save(tensorToSave, dir/filename);
-  examples.lastSavedIndex = examples.currentExampleIndex;
-  examples.contexts.clear();
-  examples.classes.clear();
-}
-
 void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
 {
   torch::AutoGradMode useGrad(false);
@@ -85,16 +75,9 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
 
     std::vector<std::vector<long>> context;
 
-    auto & contexts = examplesPerState[config.getState()].contexts;
-    auto & classes = examplesPerState[config.getState()].classes;
-    auto & currentExampleIndex = examplesPerState[config.getState()].currentExampleIndex;
-    auto & lastSavedIndex = examplesPerState[config.getState()].lastSavedIndex;
-
     try
     {
       context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
-      for (auto & element : context)
-        contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone());
     } catch(std::exception & e)
     {
       util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
@@ -134,15 +117,12 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
     }
 
     int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
-    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
-    gold[0] = goldIndex;
 
-    currentExampleIndex += context.size();
     totalNbExamples += context.size();
 
-    classes.insert(classes.end(), context.size(), gold);
-    if (currentExampleIndex-lastSavedIndex >= maxNbExamplesPerFile)
-      saveExamples(config.getState(), examplesPerState[config.getState()], dir);
+    examplesPerState[config.getState()].addContext(context); // Must run before addClass(), which pads classes up to contexts.size().
+    examplesPerState[config.getState()].addClass(goldIndex);
+    examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile);
 
     transition->apply(config);
     config.addToHistory(transition->getName());
@@ -162,8 +142,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
   }
 
   for (auto & it : examplesPerState)
-    if (!it.second.contexts.empty())
-      saveExamples(it.first, it.second, dir);
+    it.second.saveIfNeeded(it.first, dir, 0); // Threshold 0: flush whatever is still buffered.
 
   std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
   if (!f)
@@ -258,3 +237,35 @@ float Trainer::evalOnDev(bool printAdvancement)
   return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
 }
 
+void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold) // Writes the buffered examples to one file in dir and empties the buffers.
+{
+  if (currentExampleIndex-lastSavedIndex < (int)threshold) // Fewer than threshold examples since the last save.
+    return;
+  if (contexts.empty()) // Nothing buffered.
+    return;
+
+  auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1); // Contexts and classes concatenated along dimension 1.
+  auto filename = fmt::format("{}_{}-{}.tensor", state, lastSavedIndex, currentExampleIndex-1); // Name holds the state and the inclusive index range covered.
+  torch::save(tensorToSave, dir/filename);
+  lastSavedIndex = currentExampleIndex;
+  contexts.clear();
+  classes.clear();
+}
+
+void Trainer::Examples::addContext(std::vector<std::vector<long>> & context) // Buffers one tensor per element of context.
+{
+  for (auto & element : context)
+    contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone()); // clone() copies the data so the stored tensor does not alias element's buffer.
+
+  currentExampleIndex += context.size();
+}
+
+void Trainer::Examples::addClass(int goldIndex) // Pads classes with goldIndex up to the length of contexts.
+{
+  auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
+  gold[0] = goldIndex;
+
+  while (classes.size() < contexts.size())
+    classes.emplace_back(gold);
+}
+
-- 
GitLab