Skip to content
Snippets Groups Projects
ConfigDataset.cpp 818 B
Newer Older
  • Learn to ignore specific revisions
  • #include "ConfigDataset.hpp"
    
    
    /// Builds the dataset by flattening every (context, class) pair into one
    /// tensor so that get() can slice an example out of it by offset.
    ///
    /// Layout of `data` along dim 0:
    ///   [context_0, class_0, context_1, class_1, ...]
    /// where every context occupies `contextSize` rows and every class 1 row
    /// (this is the stride get() relies on: index*(contextSize+1)).
    ///
    /// @param contexts one tensor per example; all must share the same size(0)
    /// @param classes  one tensor per example, each with exactly one row on dim 0
    /// Errors are reported through util::myThrow when the two vectors differ in
    /// length, when a context's size(0) differs from the others, or when a class
    /// tensor does not have exactly one row.
    ConfigDataset::ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes)
    {
      if (contexts.size() != classes.size())
        util::myThrow(fmt::format("contexts.size()={} classes.size()={}", contexts.size(), classes.size()));

      size_ = contexts.size();

      // An empty dataset has no context to read the size from (back() would be
      // undefined behaviour) and torch::cat rejects an empty list, so stop here
      // and leave `data` undefined; size() reports 0 so get() is never reached.
      if (contexts.empty())
      {
        contextSize = 0;
        return;
      }

      contextSize = contexts.back().size(0);

      std::vector<torch::Tensor> total;
      total.reserve(2 * contexts.size());
      for (std::size_t i = 0; i < contexts.size(); i++)
      {
        // get() computes offsets assuming fixed-size records; reject anything
        // that would silently shift every subsequent example.
        if (contexts[i].size(0) != contextSize)
          util::myThrow(fmt::format("contexts[{}].size(0)={} expected {}", i, contexts[i].size(0), contextSize));
        if (classes[i].size(0) != 1)
          util::myThrow(fmt::format("classes[{}].size(0)={} expected 1", i, classes[i].size(0)));

        total.emplace_back(contexts[i]);
        total.emplace_back(classes[i]);
      }

      data = torch::cat(total);
    }
    
    /// Number of examples in the dataset, i.e. the number of (context, class)
    /// pairs given to the constructor. The size is always known, so this never
    /// returns an empty optional.
    torch::optional<size_t> ConfigDataset::size() const
    {
      // A non-void function must return on every path; falling off the end is UB.
      return size_;
    }
    
    torch::data::Example<> ConfigDataset::get(size_t index)
    {
    
      return {data.narrow(0, index*(contextSize+1), contextSize), data.narrow(0, index*(contextSize+1)+contextSize, 1)};