From 9b517e7197af66c45ef083549b6a47e6cb56ebaa Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 15 Mar 2020 14:22:43 +0100
Subject: [PATCH] ConfigDataset uses less memory, and memory usage is more
 stable

---
 torch_modules/include/ConfigDataset.hpp |  7 ++++---
 torch_modules/src/ConfigDataset.cpp     | 19 ++++++++++++++++---
 2 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/torch_modules/include/ConfigDataset.hpp b/torch_modules/include/ConfigDataset.hpp
index 7aa878e..1c60de7 100644
--- a/torch_modules/include/ConfigDataset.hpp
+++ b/torch_modules/include/ConfigDataset.hpp
@@ -8,12 +8,13 @@ class ConfigDataset : public torch::data::Dataset<ConfigDataset>
 {
   private :
 
-  std::vector<torch::Tensor> contexts;
-  std::vector<torch::Tensor> classes;
+  torch::Tensor data;
+  std::size_t size_{0};
+  std::size_t contextSize{0};
 
   public :
 
-  explicit ConfigDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes);
+  explicit ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes);
   torch::optional<size_t> size() const override;
   torch::data::Example<> get(size_t index) override;
 };
diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp
index e2d3853..439cbcc 100644
--- a/torch_modules/src/ConfigDataset.cpp
+++ b/torch_modules/src/ConfigDataset.cpp
@@ -1,16 +1,29 @@
 #include "ConfigDataset.hpp"
 
-ConfigDataset::ConfigDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes) : contexts(contexts), classes(classes)
+ConfigDataset::ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes)
 {
+  if (contexts.size() != classes.size())
+    util::myThrow(fmt::format("contexts.size()={} classes.size()={}", contexts.size(), classes.size()));
+
+  size_ = contexts.size();
+  contextSize = contexts.back().size(0);
+  std::vector<torch::Tensor> total;
+  for (unsigned int i = 0; i < contexts.size(); i++)
+  {
+    total.emplace_back(contexts[i]);
+    total.emplace_back(classes[i]);
+  }
+
+  data = torch::cat(total);
 }
 
 torch::optional<size_t> ConfigDataset::size() const
 {
-  return contexts.size();
+  return size_;
 }
 
 torch::data::Example<> ConfigDataset::get(size_t index)
 {
-  return {contexts[index], classes[index]};
+  return {data.narrow(0, index*(contextSize+1), contextSize), data.narrow(0, index*(contextSize+1)+contextSize, 1)};
 }
 
-- 
GitLab