Skip to content
Snippets Groups Projects
Commit a81d26c9 authored by Franck Dary's avatar Franck Dary
Browse files

Better distribution of batches between states

parent a31419f9
No related branches found
No related tags found
No related merge requests found
......@@ -8,11 +8,30 @@ class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDatase
{
private :
// Per-state bucket of serialized training examples. Each Holder owns the
// list of tensor files extracted for one parser state and streams batches
// from them one file at a time.
struct Holder
{
// Name of the parser state this holder serves; returned with every batch.
std::string state;
// Paths of the serialized tensor files belonging to this state.
std::vector<std::string> files;
// Content of the currently loaded file (rows = examples; last column = class).
torch::Tensor loadedTensor;
// Index into `files` of the currently loaded tensor.
int loadedTensorIndex{0};
// Row index of the next example to emit from loadedTensor.
int nextIndexToGive{0};
// Total number of examples across all files (sum of registered file sizes).
std::size_t size_{0};
// Number of examples emitted since the last reset().
std::size_t nbGiven{0};
Holder(std::string state);
// Registers one file and adds `filesize` examples to size_.
void addFile(std::string filename, int filesize);
// Shuffles files, rewinds all cursors and loads the first file.
void reset();
std::size_t size() const;
std::size_t sizeLeft() const;
// Returns (inputs, classes, state) for up to batchSize examples, or an
// empty optional when the holder is exhausted for this epoch.
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> get_batch(std::size_t batchSize);
};
private :
std::size_t size_{0};
std::vector<std::tuple<int,int,std::filesystem::path,std::string>> exampleLocations;
torch::Tensor loadedTensor;
std::optional<std::size_t> loadedTensorIndex;
std::size_t nextIndexToGive{0};
std::map<std::string,Holder> holders;
std::map<std::string,int> nbToGive;
std::vector<std::string> order;
public :
......@@ -22,6 +41,7 @@ class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDatase
void reset() override;
void load(torch::serialize::InputArchive &) override;
void save(torch::serialize::OutputArchive &) const override;
void computeNbToGive();
};
#endif
......@@ -11,10 +11,15 @@ ConfigDataset::ConfigDataset(std::filesystem::path dir)
continue;
auto state = util::split(stem, '_')[0];
auto splited = util::split(util::split(stem, '_')[1], '-');
exampleLocations.emplace_back(std::make_tuple(std::stoi(splited[0]), std::stoi(splited[1]), entry.path(), state));
size_ += 1 + std::get<1>(exampleLocations.back()) - std::get<0>(exampleLocations.back());
int fileSize = 1 + std::stoi(splited[1]) - std::stoi(splited[0]);
size_ += fileSize;
if (!holders.count(state))
{
holders.emplace(state, Holder(state));
order.emplace_back(state);
}
holders.at(state).addFile(entry.path().string(), fileSize);
}
}
c10::optional<std::size_t> ConfigDataset::size() const
......@@ -24,47 +29,45 @@ c10::optional<std::size_t> ConfigDataset::size() const
// NOTE(review): this span is a commit-diff rendering with lines of the
// REMOVED implementation (a single global cursor walking exampleLocations
// file by file) interleaved with lines of the ADDED implementation
// (round-robin over per-state Holders, rationed by nbToGive). It is not
// compilable as-is; the markers below are best-effort — verify against the
// repository before editing.
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::get_batch(std::size_t batchSize)
{
// [old] lazily initialize the global cursor on first call.
if (!loadedTensorIndex.has_value())
// [new] re-shuffle the state visiting order on every call.
std::random_shuffle(order.begin(), order.end());
// [new] first pass: try each state that still has batch credit.
for (auto & state : order)
{
// [old] rewind cursors and load the first tensor file.
loadedTensorIndex = 0;
nextIndexToGive = 0;
torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
// [new] spend one unit of this state's credit and ask its holder.
if (nbToGive.at(state) > 0)
{
nbToGive.at(state)--;
auto res = holders.at(state).get_batch(batchSize);
if (res.has_value())
return res;
else
// [new] holder exhausted: zero its credit for this round.
nbToGive.at(state) = 0;
}
}
// [old] current file consumed: advance to the next file or signal the end.
if ((int)nextIndexToGive >= loadedTensor.size(0))
{
nextIndexToGive = 0;
loadedTensorIndex = loadedTensorIndex.value() + 1;
if (loadedTensorIndex >= exampleLocations.size())
return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
}
// [new] all credits spent: recompute the per-state batch quotas.
computeNbToGive();
// [old] slice the next batch out of the loaded tensor
// (inputs = all columns but the last, classes = last column).
std::tuple<torch::Tensor, torch::Tensor, std::string> batch;
if ((int)nextIndexToGive + (int)batchSize < loadedTensor.size(0))
{
std::get<0>(batch) = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, 0, loadedTensor.size(1)-1);
std::get<1>(batch) = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, loadedTensor.size(1)-1, 1);
nextIndexToGive += batchSize;
}
else
// [new] second pass after the quota refresh.
for (auto & state : order)
{
// [old] tail of the file: emit the remaining rows.
std::get<0>(batch) = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, 0, loadedTensor.size(1)-1);
std::get<1>(batch) = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, loadedTensor.size(1)-1, 1);
nextIndexToGive = loadedTensor.size(0);
if (nbToGive.at(state) > 0)
{
nbToGive.at(state)--;
auto res = holders.at(state).get_batch(batchSize);
if (res.has_value())
return res;
else
nbToGive.at(state) = 0;
}
}
// [old] attach the state name stored with the file, then return.
std::get<2>(batch) = std::get<3>(exampleLocations[loadedTensorIndex.value()]);
return batch;
// [new] every holder is exhausted: end of epoch.
return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
}
// NOTE(review): commit-diff rendering — removed lines (global-cursor state)
// interleaved with added lines (per-holder reset + quota computation); not
// compilable as-is. std::random_shuffle was removed in C++17; confirm the
// project's language standard.
void ConfigDataset::reset()
{
// [old] shuffle file order and clear the global cursor.
std::random_shuffle(exampleLocations.begin(), exampleLocations.end());
loadedTensorIndex = std::optional<std::size_t>();
nextIndexToGive = 0;
// [new] restart every per-state holder and recompute batch quotas.
for (auto & it : holders)
it.second.reset();
computeNbToGive();
}
void ConfigDataset::load(torch::serialize::InputArchive &)
......@@ -75,3 +78,65 @@ void ConfigDataset::save(torch::serialize::OutputArchive &) const
{
}
// Register one serialized tensor file for this state and grow the holder's
// total example count by the number of examples the file contains.
void ConfigDataset::Holder::addFile(std::string filename, int filesize)
{
size_ += filesize;
files.push_back(filename);
}
// Start a fresh epoch for this state: shuffle the file order, rewind all
// cursors/counters and preload the first tensor file onto the network device.
// NOTE(review): std::random_shuffle was deprecated in C++14 and removed in
// C++17 — prefer std::shuffle with an explicit engine; confirm the project's
// language standard before changing.
// NOTE(review): assumes `files` is non-empty (holders are only created on
// their first addFile) — files[0] would be out-of-bounds otherwise.
void ConfigDataset::Holder::reset()
{
std::random_shuffle(files.begin(), files.end());
loadedTensorIndex = 0;
nextIndexToGive = 0;
nbGiven = 0;
torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
}
// Emit the next batch for this state, or an empty optional once every file
// has been fully consumed. Rows come from loadedTensor; the last column is
// the gold class, the other columns are the network inputs.
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize)
{
// Already walked past the last file: exhausted for this epoch.
if (loadedTensorIndex >= (int)files.size())
return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
// Current file consumed: advance to the next one (or report exhaustion).
if (nextIndexToGive >= loadedTensor.size(0))
{
loadedTensorIndex++;
if (loadedTensorIndex >= (int)files.size())
return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
nextIndexToGive = 0;
torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
}
// A batch may be short: give whatever remains in the current file.
int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);
nbGiven += nbElementsToGive;
// narrow() returns views into loadedTensor — no copy is made here.
auto batch = loadedTensor.narrow(0, nextIndexToGive, nbElementsToGive);
nextIndexToGive += nbElementsToGive;
// Split off (inputs, classes) and tag the batch with the state name.
return std::make_tuple(batch.narrow(1, 0, batch.size(1)-1), batch.narrow(1, batch.size(1)-1, 1), state);
}
// A Holder starts empty; files are registered afterwards through addFile().
ConfigDataset::Holder::Holder(std::string state)
{
this->state = state;
}
std::size_t ConfigDataset::Holder::size() const
{
return size_;
}
// Number of examples not yet emitted in the current epoch.
// Guard against unsigned wrap-around: if nbGiven ever exceeded size_ (e.g. a
// file larger than the size advertised in its name), size_-nbGiven would wrap
// to a huge value that computeNbToGive() then uses as a divisor/quota.
std::size_t ConfigDataset::Holder::sizeLeft() const
{
return nbGiven >= size_ ? 0 : size_ - nbGiven;
}
// Rebalance how many batches each state may emit before the next rebalance:
// each state gets floor(sizeLeft / smallestNonZeroSizeLeft) credits, so
// states with more remaining examples are drawn from proportionally more
// often (the motivation of this commit).
void ConfigDataset::computeNbToGive()
{
// Find the smallest non-empty remaining pool; stays at INT_MAX (making all
// quotas 0) when every holder is exhausted.
int smallestSize = std::numeric_limits<int>::max();
for (auto & it : holders)
{
int sizeLeft = it.second.sizeLeft();
if (sizeLeft > 0 and sizeLeft < smallestSize)
smallestSize = sizeLeft;
}
// Integer division already floors for non-negative operands; this avoids
// the double round-trip of std::floor(1.0*a/b) and its precision hazards.
for (auto & it : holders)
nbToGive[it.first] = it.second.sizeLeft() / smallestSize;
}
......@@ -16,6 +16,10 @@ class Trainer
int currentExampleIndex{0};
int lastSavedIndex{0};
void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold);
void addContext(std::vector<std::vector<long>> & context);
void addClass(int goldIndex);
};
private :
......@@ -37,7 +41,6 @@ class Trainer
void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);
void saveExamples(std::string state, Examples & examples, std::filesystem::path dir);
public :
......
......@@ -33,16 +33,6 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys
devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}
// Serialize the buffered examples of one state to
// "<state>_<firstIndex>-<lastIndex>.tensor" in dir, then clear the buffers.
// Inputs and classes are stacked row-wise and concatenated column-wise, so
// each saved row is [context..., class].
void Trainer::saveExamples(std::string state, Examples & examples, std::filesystem::path dir)
{
auto filename = fmt::format("{}_{}-{}.tensor", state, examples.lastSavedIndex, examples.currentExampleIndex-1);
auto stackedContexts = torch::stack(examples.contexts);
auto stackedClasses = torch::stack(examples.classes);
torch::save(torch::cat({stackedContexts, stackedClasses}, 1), dir/filename);
examples.lastSavedIndex = examples.currentExampleIndex;
examples.contexts.clear();
examples.classes.clear();
}
void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
{
torch::AutoGradMode useGrad(false);
......@@ -85,16 +75,9 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
std::vector<std::vector<long>> context;
auto & contexts = examplesPerState[config.getState()].contexts;
auto & classes = examplesPerState[config.getState()].classes;
auto & currentExampleIndex = examplesPerState[config.getState()].currentExampleIndex;
auto & lastSavedIndex = examplesPerState[config.getState()].lastSavedIndex;
try
{
context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
for (auto & element : context)
contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone());
} catch(std::exception & e)
{
util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
......@@ -134,15 +117,12 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
}
int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
gold[0] = goldIndex;
currentExampleIndex += context.size();
totalNbExamples += context.size();
classes.insert(classes.end(), context.size(), gold);
if (currentExampleIndex-lastSavedIndex >= maxNbExamplesPerFile)
saveExamples(config.getState(), examplesPerState[config.getState()], dir);
examplesPerState[config.getState()].addContext(context);
examplesPerState[config.getState()].addClass(goldIndex);
examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile);
transition->apply(config);
config.addToHistory(transition->getName());
......@@ -162,8 +142,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
}
for (auto & it : examplesPerState)
if (!it.second.contexts.empty())
saveExamples(it.first, it.second, dir);
it.second.saveIfNeeded(it.first, dir, 0);
std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
if (!f)
......@@ -258,3 +237,35 @@ float Trainer::evalOnDev(bool printAdvancement)
return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
}
void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold)
{
if (currentExampleIndex-lastSavedIndex < (int)threshold)
return;
if (contexts.empty())
return;
auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
auto filename = fmt::format("{}_{}-{}.tensor", state, lastSavedIndex, currentExampleIndex-1);
torch::save(tensorToSave, dir/filename);
lastSavedIndex = currentExampleIndex;
contexts.clear();
classes.clear();
}
// Convert each extracted context row into an owning 1-D long tensor and
// append it to the buffer; advances the running example counter by the
// number of rows. The clone() is required because from_blob only wraps the
// row's memory, which dies with `context`.
void Trainer::Examples::addContext(std::vector<std::vector<long>> & context)
{
currentExampleIndex += context.size();
for (auto & row : context)
{
auto wrapped = torch::from_blob(row.data(), {(long)row.size()}, torch::TensorOptions(torch::kLong));
contexts.emplace_back(wrapped.clone());
}
}
// Append the gold class once per context row added since the previous call,
// keeping `classes` aligned one-to-one with `contexts`. All entries added
// here share the same 1-element tensor object (torch tensor handles share
// storage), which is fine since they are only read when stacked for saving.
void Trainer::Examples::addClass(int goldIndex)
{
auto label = torch::zeros(1, torch::TensorOptions(torch::kLong));
label[0] = goldIndex;
for (std::size_t i = classes.size(); i < contexts.size(); ++i)
classes.emplace_back(label);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment