Skip to content
Snippets Groups Projects
ConfigDataset.cpp 652 B
Newer Older
  • Learn to ignore specific revisions
  • #include "ConfigDataset.hpp"
    
    
    /// @brief Dataset over a list of configurations paired with class labels.
    ///
    /// @param configs    configurations; get(index) reads configs[index].
    /// @param classes    class label per configuration; get(index) sets position
    ///                   classes[index] of the target tensor to 1.
    /// @param nbClasses  length of the target tensor built by get().
    /// @param dict       dictionary handed to Config::extractContext in get().
    ///
    /// NOTE(review): `configs` holds unique_ptrs and cannot be copied, so the
    /// matching member is presumably a reference. If the other members are
    /// references too, the arguments must outlive this dataset. Confirm against
    /// ConfigDataset.hpp.
    ConfigDataset::ConfigDataset(std::vector<std::unique_ptr<Config>> const & configs,
                                 std::vector<std::size_t> const & classes,
                                 std::size_t nbClasses,
                                 Dict & dict)
      : configs(configs),
        classes(classes),
        nbClasses(nbClasses),
        dict(dict)
    {
    }
    
    /// @brief Number of examples, i.e. the number of stored configurations.
    /// @return an engaged optional holding configs.size(); never empty.
    torch::optional<size_t> ConfigDataset::size() const
    {
      const auto count = configs.size();
      return torch::optional<size_t>(count);
    }
    
    torch::data::Example<> ConfigDataset::get(size_t index)
    {
    
      auto context = configs[index]->extractContext(5,5,dict);
      auto tensorClass = torch::zeros(nbClasses);
      tensorClass[classes[index]] = 1;
    
      return {torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone(), tensorClass};