From a81d26c906425fbfde4cdd03d55549eb2d1aaff2 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 21 Apr 2020 15:30:10 +0200
Subject: [PATCH] better distribution of batches between states

---
 torch_modules/include/ConfigDataset.hpp |  28 ++++-
 torch_modules/src/ConfigDataset.cpp     | 131 ++++++++++++++++++------
 trainer/include/Trainer.hpp             |   5 +-
 trainer/src/Trainer.cpp                 |  61 ++++++-----
 4 files changed, 162 insertions(+), 63 deletions(-)

diff --git a/torch_modules/include/ConfigDataset.hpp b/torch_modules/include/ConfigDataset.hpp
index ee8d46c..7ea2335 100644
--- a/torch_modules/include/ConfigDataset.hpp
+++ b/torch_modules/include/ConfigDataset.hpp
@@ -8,11 +8,30 @@ class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDatase
 {
   private :
 
+  struct Holder
+  {
+    std::string state;
+    std::vector<std::string> files;
+    torch::Tensor loadedTensor;
+    int loadedTensorIndex{0};
+    int nextIndexToGive{0};
+    std::size_t size_{0};
+    std::size_t nbGiven{0};
+
+    Holder(std::string state);
+    void addFile(std::string filename, int filesize);
+    void reset();
+    std::size_t size() const;
+    std::size_t sizeLeft() const;
+    c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> get_batch(std::size_t batchSize);
+  };
+
+  private :
+
   std::size_t size_{0};
-  std::vector<std::tuple<int,int,std::filesystem::path,std::string>> exampleLocations;
-  torch::Tensor loadedTensor;
-  std::optional<std::size_t> loadedTensorIndex;
-  std::size_t nextIndexToGive{0};
+  std::map<std::string,Holder> holders;
+  std::map<std::string,int> nbToGive;
+  std::vector<std::string> order;
 
   public :
 
@@ -22,6 +41,7 @@ class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDatase
   void reset() override;
   void load(torch::serialize::InputArchive &) override;
   void save(torch::serialize::OutputArchive &) const override;
+  void computeNbToGive();
 };
 
 #endif
diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp
index 3a4a2c5..3a4e646 100644
--- a/torch_modules/src/ConfigDataset.cpp
+++ b/torch_modules/src/ConfigDataset.cpp
@@ -11,10 +11,15 @@ ConfigDataset::ConfigDataset(std::filesystem::path dir)
         continue;
       auto state = util::split(stem, '_')[0];
       auto splited = util::split(util::split(stem, '_')[1], '-');
-      exampleLocations.emplace_back(std::make_tuple(std::stoi(splited[0]), std::stoi(splited[1]), entry.path(), state));
-      size_ += 1 + std::get<1>(exampleLocations.back()) - std::get<0>(exampleLocations.back());
+      int fileSize = 1 + std::stoi(splited[1]) - std::stoi(splited[0]);
+      size_ += fileSize;
+      if (!holders.count(state))
+      {
+        holders.emplace(state, Holder(state));
+        order.emplace_back(state);
+      }
+      holders.at(state).addFile(entry.path().string(), fileSize);
     }
-
 }
 
 c10::optional<std::size_t> ConfigDataset::size() const
@@ -24,47 +29,45 @@ c10::optional<std::size_t> ConfigDataset::size() const
 
 c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::get_batch(std::size_t batchSize)
 {
-  if (!loadedTensorIndex.has_value())
+  std::random_shuffle(order.begin(), order.end());
+
+  for (auto & state : order)
   {
-    loadedTensorIndex = 0;
-    nextIndexToGive = 0;
-    torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
+    if (nbToGive.at(state) > 0)
+    {
+      nbToGive.at(state)--;
+      auto res = holders.at(state).get_batch(batchSize);
+      if (res.has_value())
+        return res;
+      else
+        nbToGive.at(state) = 0;
+    }
   }
-  if ((int)nextIndexToGive >= loadedTensor.size(0))
-  {
-    nextIndexToGive = 0;
-    loadedTensorIndex = loadedTensorIndex.value() + 1;
-
-    if (loadedTensorIndex >= exampleLocations.size())
-      return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
 
-    torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
-  }
+  computeNbToGive();
 
-  std::tuple<torch::Tensor, torch::Tensor, std::string> batch;
-  if ((int)nextIndexToGive + (int)batchSize < loadedTensor.size(0))
-  {
-    std::get<0>(batch) = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, 0, loadedTensor.size(1)-1);
-    std::get<1>(batch) = loadedTensor.narrow(0, nextIndexToGive, batchSize).narrow(1, loadedTensor.size(1)-1, 1);
-    nextIndexToGive += batchSize;
-  }
-  else
+  for (auto & state : order)
   {
-    std::get<0>(batch) = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, 0, loadedTensor.size(1)-1);
-    std::get<1>(batch) = loadedTensor.narrow(0, nextIndexToGive, loadedTensor.size(0)-nextIndexToGive).narrow(1, loadedTensor.size(1)-1, 1);
-    nextIndexToGive = loadedTensor.size(0);
+    if (nbToGive.at(state) > 0)
+    {
+      nbToGive.at(state)--;
+      auto res = holders.at(state).get_batch(batchSize);
+      if (res.has_value())
+        return res;
+      else
+        nbToGive.at(state) = 0;
+    }
   }
 
-  std::get<2>(batch) = std::get<3>(exampleLocations[loadedTensorIndex.value()]);
-
-  return batch;
+  return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
 }
 
 void ConfigDataset::reset()
 {
-  std::random_shuffle(exampleLocations.begin(), exampleLocations.end());
-  loadedTensorIndex = std::optional<std::size_t>();
-  nextIndexToGive = 0;
+  for (auto & it : holders)
+    it.second.reset();
+
+  computeNbToGive();
 }
 
 void ConfigDataset::load(torch::serialize::InputArchive &)
@@ -75,3 +78,65 @@ void ConfigDataset::save(torch::serialize::OutputArchive &) const
 {
 }
 
+void ConfigDataset::Holder::addFile(std::string filename, int filesize)
+{
+  files.emplace_back(filename);
+  size_ += filesize;
+}
+
+void ConfigDataset::Holder::reset()
+{
+  std::random_shuffle(files.begin(), files.end());
+  loadedTensorIndex = 0;
+  nextIndexToGive = 0;
+  nbGiven = 0;
+  torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
+}
+
+c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize)
+{
+  if (loadedTensorIndex >= (int)files.size())
+    return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
+  if (nextIndexToGive >= loadedTensor.size(0))
+  {
+    loadedTensorIndex++;
+    if (loadedTensorIndex >= (int)files.size())
+      return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
+    nextIndexToGive = 0;
+    torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
+  }
+
+  int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);
+  nbGiven += nbElementsToGive;
+  auto batch = loadedTensor.narrow(0, nextIndexToGive, nbElementsToGive);
+  nextIndexToGive += nbElementsToGive;
+  return std::make_tuple(batch.narrow(1, 0, batch.size(1)-1), batch.narrow(1, batch.size(1)-1, 1), state);
+}
+
+ConfigDataset::Holder::Holder(std::string state) : state(state)
+{
+}
+
+std::size_t ConfigDataset::Holder::size() const
+{
+  return size_;
+}
+
+std::size_t ConfigDataset::Holder::sizeLeft() const
+{
+  return size_-nbGiven;
+}
+
+void ConfigDataset::computeNbToGive()
+{
+  int smallestSize = std::numeric_limits<int>::max();
+  for (auto & it : holders)
+  {
+    int sizeLeft = it.second.sizeLeft();
+    if (sizeLeft > 0 and sizeLeft < smallestSize)
+      smallestSize = sizeLeft;
+  }
+  for (auto & it : holders)
+    nbToGive[it.first] = std::floor(1.0*it.second.sizeLeft()/smallestSize);
+}
+
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 91aedbe..c087469 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -16,6 +16,10 @@ class Trainer
 
     int currentExampleIndex{0};
     int lastSavedIndex{0};
+
+    void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold);
+    void addContext(std::vector<std::vector<long>> & context);
+    void addClass(int goldIndex);
   };
 
   private :
@@ -37,7 +41,6 @@ class Trainer
 
   void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
   float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);
-void saveExamples(std::string state, Examples & examples, std::filesystem::path dir);
 
   public :
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index f76f220..23b131b 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -33,16 +33,6 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys
   devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 }
 
-void Trainer::saveExamples(std::string state, Examples & examples, std::filesystem::path dir)
-{
-  auto tensorToSave = torch::cat({torch::stack(examples.contexts), torch::stack(examples.classes)}, 1);
-  auto filename = fmt::format("{}_{}-{}.tensor", state, examples.lastSavedIndex, examples.currentExampleIndex-1);
-  torch::save(tensorToSave, dir/filename);
-  examples.lastSavedIndex = examples.currentExampleIndex;
-  examples.contexts.clear();
-  examples.classes.clear();
-}
-
 void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
 {
   torch::AutoGradMode useGrad(false);
@@ -85,16 +75,9 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
 
     std::vector<std::vector<long>> context;
 
-    auto & contexts = examplesPerState[config.getState()].contexts;
-    auto & classes = examplesPerState[config.getState()].classes;
-    auto & currentExampleIndex = examplesPerState[config.getState()].currentExampleIndex;
-    auto & lastSavedIndex = examplesPerState[config.getState()].lastSavedIndex;
-
     try
     {
       context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
-      for (auto & element : context)
-        contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone());
     } catch(std::exception & e)
     {
       util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
@@ -134,15 +117,12 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
     }
 
     int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
-    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
-    gold[0] = goldIndex;
 
-    currentExampleIndex += context.size();
     totalNbExamples += context.size();
-    classes.insert(classes.end(), context.size(), gold);
 
-    if (currentExampleIndex-lastSavedIndex >= maxNbExamplesPerFile)
-      saveExamples(config.getState(), examplesPerState[config.getState()], dir);
+    examplesPerState[config.getState()].addContext(context);
+    examplesPerState[config.getState()].addClass(goldIndex);
+    examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile);
 
     transition->apply(config);
     config.addToHistory(transition->getName());
@@ -162,8 +142,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
   }
 
   for (auto & it : examplesPerState)
-    if (!it.second.contexts.empty())
-      saveExamples(it.first, it.second, dir);
+    it.second.saveIfNeeded(it.first, dir, 0);
 
   std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
   if (!f)
@@ -258,3 +237,35 @@ float Trainer::evalOnDev(bool printAdvancement)
   return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
 }
 
+void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold)
+{
+  if (currentExampleIndex-lastSavedIndex < (int)threshold)
+    return;
+  if (contexts.empty())
+    return;
+
+  auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
+  auto filename = fmt::format("{}_{}-{}.tensor", state, lastSavedIndex, currentExampleIndex-1);
+  torch::save(tensorToSave, dir/filename);
+  lastSavedIndex = currentExampleIndex;
+  contexts.clear();
+  classes.clear();
+}
+
+void Trainer::Examples::addContext(std::vector<std::vector<long>> & context)
+{
+  for (auto & element : context)
+    contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone());
+
+  currentExampleIndex += context.size();
+}
+
+void Trainer::Examples::addClass(int goldIndex)
+{
+    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
+    gold[0] = goldIndex;
+
+    while (classes.size() < contexts.size())
+      classes.emplace_back(gold);
+}
+
-- 
GitLab