Skip to content
Snippets Groups Projects
Commit 9b517e71 authored by Franck Dary's avatar Franck Dary
Browse files

ConfigDataset uses less memory, and memory usage is more stable

parent fb37b7a7
No related branches found
No related tags found
No related merge requests found
......@@ -8,12 +8,13 @@ class ConfigDataset : public torch::data::Dataset<ConfigDataset>
{
private :
std::vector<torch::Tensor> contexts;
std::vector<torch::Tensor> classes;
torch::Tensor data;
std::size_t size_{0};
std::size_t contextSize{0};
public :
explicit ConfigDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes);
explicit ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes);
torch::optional<size_t> size() const override;
torch::data::Example<> get(size_t index) override;
};
......
#include "ConfigDataset.hpp"
ConfigDataset::ConfigDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes) : contexts(contexts), classes(classes)
ConfigDataset::ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes)
{
if (contexts.size() != classes.size())
util::myThrow(fmt::format("contexts.size()={} classes.size()={}", contexts.size(), classes.size()));
size_ = contexts.size();
contextSize = contexts.back().size(0);
std::vector<torch::Tensor> total;
for (unsigned int i = 0; i < contexts.size(); i++)
{
total.emplace_back(contexts[i]);
total.emplace_back(classes[i]);
}
data = torch::cat(total);
}
torch::optional<size_t> ConfigDataset::size() const
{
return contexts.size();
return size_;
}
torch::data::Example<> ConfigDataset::get(size_t index)
{
return {contexts[index], classes[index]};
return {data.narrow(0, index*(contextSize+1), contextSize), data.narrow(0, index*(contextSize+1)+contextSize, 1)};
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment