// ConfigDataset.cpp
#include "ConfigDataset.hpp"

/// Builds a dataset over a set of configurations and their target classes.
///
/// @param configs   configurations exposed as samples; sample i reads configs[i]
/// @param classes   target class index for each configuration, indexed in
///                  parallel with `configs` (sample i reads classes[i])
/// @param nbClasses number of output classes (length of the target vector)
/// @param dict      dictionary passed to Config::extractContext when building samples
///
/// NOTE(review): member declarations live in the header (not visible here). If
/// `configs`/`classes`/`dict` are stored as references, the caller must keep
/// the arguments alive for the lifetime of this dataset — confirm.
ConfigDataset::ConfigDataset(
    std::vector<std::unique_ptr<Config>> const & configs,
    std::vector<std::size_t> const & classes,
    std::size_t nbClasses,
    Dict & dict)
  : configs(configs)
  , classes(classes)
  , nbClasses(nbClasses)
  , dict(dict)
{
  // All state is captured by the member initializers above.
}

/// Reports the number of samples in the dataset: one per stored configuration.
///
/// @return the configuration count, always engaged (never an empty optional)
torch::optional<size_t> ConfigDataset::size() const
{
  const size_t nbConfigs = configs.size();
  return torch::optional<size_t>(nbConfigs);
}

torch::data::Example<> ConfigDataset::get(size_t index)
{
  auto context = configs[index]->extractContext(5,5,dict);
  auto tensorClass = torch::zeros(nbClasses);
  tensorClass[classes[index]] = 1;
  return {torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone(), tensorClass};