diff --git a/torch_modules/include/ConfigDataset.hpp b/torch_modules/include/ConfigDataset.hpp index 7aa878e62ce6d8beeabca4cd763353bb9023ff22..1c60de78ffc4cced19ed0138f45b6003bd419a35 100644 --- a/torch_modules/include/ConfigDataset.hpp +++ b/torch_modules/include/ConfigDataset.hpp @@ -8,12 +8,13 @@ class ConfigDataset : public torch::data::Dataset<ConfigDataset> { private : - std::vector<torch::Tensor> contexts; - std::vector<torch::Tensor> classes; + torch::Tensor data; + std::size_t size_{0}; + std::size_t contextSize{0}; public : - explicit ConfigDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes); + explicit ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes); torch::optional<size_t> size() const override; torch::data::Example<> get(size_t index) override; }; diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp index e2d3853312fc657c4dfd31f08197238adbab7e47..439cbcc6d32abfeeff17fd090f15bf34a3b3f438 100644 --- a/torch_modules/src/ConfigDataset.cpp +++ b/torch_modules/src/ConfigDataset.cpp @@ -1,16 +1,29 @@ #include "ConfigDataset.hpp" -ConfigDataset::ConfigDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes) : contexts(contexts), classes(classes) +ConfigDataset::ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes) { + if (contexts.size() != classes.size()) + util::myThrow(fmt::format("contexts.size()={} classes.size()={}", contexts.size(), classes.size())); + + size_ = contexts.size(); + contextSize = contexts.back().size(0); + std::vector<torch::Tensor> total; + for (unsigned int i = 0; i < contexts.size(); i++) + { + total.emplace_back(contexts[i]); + total.emplace_back(classes[i]); + } + + data = torch::cat(total); } torch::optional<size_t> ConfigDataset::size() const { - return contexts.size(); + return size_; } torch::data::Example<> ConfigDataset::get(size_t index) { - return {contexts[index], classes[index]}; + return {data.narrow(0, index*(contextSize+1), contextSize), data.narrow(0, index*(contextSize+1)+contextSize, 1)}; }