Newer
Older
/// @brief Builds a dataset over a set of configurations and their class labels.
///
/// @param configs    configurations indexed by example number; get() calls
///                   extractContext on configs[index]. NOTE(review): the
///                   member is presumably a reference (a vector of unique_ptr
///                   cannot be copied), so the caller must keep it alive.
/// @param classes    class index of each example, read as classes[index] in get().
/// @param nbClasses  number of classes, used as the length of the one-hot
///                   target tensor produced by get().
/// @param dict       dictionary handed to Config::extractContext in get().
ConfigDataset::ConfigDataset(std::vector<std::unique_ptr<Config>> const & configs,
                             std::vector<std::size_t> const & classes,
                             std::size_t nbClasses,
                             Dict & dict)
  : configs(configs),
    classes(classes),
    nbClasses(nbClasses),
    dict(dict)
{
  // All state is captured by the initializer list above; nothing else to do.
}
/// @brief Number of examples in the dataset.
///
/// There is one example per stored configuration, matching get(), which
/// reads configs[index]. The previous body had no return statement; flowing
/// off the end of a value-returning function is undefined behaviour, so any
/// caller (e.g. a sampler asking how many indices to draw) got garbage.
/// NOTE(review): assumes configs and classes have the same length — confirm
/// at construction.
///
/// @return the dataset size (always engaged).
torch::optional<size_t> ConfigDataset::size() const
{
  return configs.size();
}
torch::data::Example<> ConfigDataset::get(size_t index)
{
auto context = configs[index]->extractContext(5,5,dict);
auto tensorClass = torch::zeros(nbClasses);
tensorClass[classes[index]] = 1;
return {torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone(), tensorClass};