From 9b517e7197af66c45ef083549b6a47e6cb56ebaa Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sun, 15 Mar 2020 14:22:43 +0100 Subject: [PATCH] ConfigDataset uses less memory, and memory usage is more stable --- torch_modules/include/ConfigDataset.hpp | 7 ++++--- torch_modules/src/ConfigDataset.cpp | 19 ++++++++++++++++--- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/torch_modules/include/ConfigDataset.hpp b/torch_modules/include/ConfigDataset.hpp index 7aa878e..1c60de7 100644 --- a/torch_modules/include/ConfigDataset.hpp +++ b/torch_modules/include/ConfigDataset.hpp @@ -8,12 +8,13 @@ class ConfigDataset : public torch::data::Dataset<ConfigDataset> { private : - std::vector<torch::Tensor> contexts; - std::vector<torch::Tensor> classes; + torch::Tensor data; + std::size_t size_{0}; + std::size_t contextSize{0}; public : - explicit ConfigDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes); + explicit ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes); torch::optional<size_t> size() const override; torch::data::Example<> get(size_t index) override; }; diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp index e2d3853..439cbcc 100644 --- a/torch_modules/src/ConfigDataset.cpp +++ b/torch_modules/src/ConfigDataset.cpp @@ -1,16 +1,29 @@ #include "ConfigDataset.hpp" -ConfigDataset::ConfigDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes) : contexts(contexts), classes(classes) +ConfigDataset::ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes) { + if (contexts.size() != classes.size()) + util::myThrow(fmt::format("contexts.size()={} classes.size()={}", contexts.size(), classes.size())); + + size_ = contexts.size(); + contextSize = contexts.back().size(0); + std::vector<torch::Tensor> total; + for (unsigned int i = 0; i < contexts.size(); i++) + { + total.emplace_back(contexts[i]); + total.emplace_back(classes[i]); + } + + data = torch::cat(total); } torch::optional<size_t> ConfigDataset::size() const { - return contexts.size(); + return size_; } torch::data::Example<> ConfigDataset::get(size_t index) { - return {contexts[index], classes[index]}; + return {data.narrow(0, index*(contextSize+1), contextSize), data.narrow(0, index*(contextSize+1)+contextSize, 1)}; } -- GitLab