#include "ConfigDataset.hpp" ConfigDataset::ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes) { if (contexts.size() != classes.size()) util::myThrow(fmt::format("contexts.size()={} classes.size()={}", contexts.size(), classes.size())); size_ = contexts.size(); contextSize = contexts.back().size(0); std::vector<torch::Tensor> total; for (unsigned int i = 0; i < contexts.size(); i++) { total.emplace_back(contexts[i]); total.emplace_back(classes[i]); } data = torch::cat(total); } torch::optional<size_t> ConfigDataset::size() const { return size_; } torch::data::Example<> ConfigDataset::get(size_t index) { return {data.narrow(0, index*(contextSize+1), contextSize), data.narrow(0, index*(contextSize+1)+contextSize, 1)}; }