Newer
Older
ConfigDataset::ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes)
if (contexts.size() != classes.size())
util::myThrow(fmt::format("contexts.size()={} classes.size()={}", contexts.size(), classes.size()));
size_ = contexts.size();
contextSize = contexts.back().size(0);
std::vector<torch::Tensor> total;
for (unsigned int i = 0; i < contexts.size(); i++)
{
total.emplace_back(contexts[i]);
total.emplace_back(classes[i]);
}
data = torch::cat(total);
}
torch::optional<size_t> ConfigDataset::size() const
{
return size_;
}
torch::data::Example<> ConfigDataset::get(size_t index)
{
return {data.narrow(0, index*(contextSize+1), contextSize), data.narrow(0, index*(contextSize+1)+contextSize, 1)};