#include "ConfigDataset.hpp"

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <random>
#include <stdexcept>
#include <string>

#include "NeuralNetwork.hpp"

/// @brief Index every example file found directly inside @p dir.
///
/// Each regular file (except one whose stem is "extracted") must have a stem of
/// the form `<state>_<first>-<last>`. The pair (first, last) is recorded together
/// with the file path and the `<state>` prefix. `last - first + 1` is added to
/// size_, i.e. the range is treated as inclusive.
///
/// @param dir Directory to scan (not recursive).
/// @throws std::runtime_error if a file stem does not match the expected pattern
///         or its range is reversed (last < first).
/// @throws std::invalid_argument / std::out_of_range from std::stoi on non-numeric bounds.
ConfigDataset::ConfigDataset(std::filesystem::path dir)
{
  for (const auto & entry : std::filesystem::directory_iterator(dir))
  {
    if (!entry.is_regular_file())
      continue;

    const auto & path = entry.path();
    const auto stem = path.stem().string();
    if (stem == "extracted")
      continue;

    // Split once and validate before indexing: a stray file in the directory
    // used to cause out-of-bounds access instead of a diagnosable error.
    const auto nameParts = util::split(stem, '_');
    if (nameParts.size() < 2)
      throw std::runtime_error("ConfigDataset: unexpected file name '" + path.string() + "', expected stem <state>_<first>-<last>");

    const auto range = util::split(nameParts[1], '-');
    if (range.size() < 2)
      throw std::runtime_error("ConfigDataset: unexpected file name '" + path.string() + "', expected stem <state>_<first>-<last>");

    const int first = std::stoi(range[0]);
    const int last = std::stoi(range[1]);
    // A reversed range would make 1 + last - first negative and wrap size_.
    if (last < first)
      throw std::runtime_error("ConfigDataset: reversed range in file name '" + path.string() + "'");

    exampleLocations.emplace_back(std::make_tuple(first, last, path, nameParts[0]));
    size_ += 1 + last - first;
  }
}

/// @brief Total number of examples across all indexed files.
/// @return The count accumulated by the constructor from the file-name ranges;
///         never empty.
c10::optional<std::size_t> ConfigDataset::size() const
{
  const std::size_t total = size_;
  return c10::optional<std::size_t>(total);
}

/// @brief Return the next batch of rows, loading example files one at a time.
///
/// Rows are served sequentially from the currently loaded file (in the order of
/// exampleLocations). When it is exhausted, the next file is loaded onto
/// NeuralNetworkImpl::device. The last batch of a file may be shorter than
/// @p batchSize. Files containing zero rows are skipped.
///
/// @param batchSize Maximum number of rows in the returned batch.
/// @return (all columns but the last, the last column as a 1-wide slice, the
///         file's state string). The last column is presumably the target; TODO
///         confirm. Returns an empty optional when there are no files or the
///         epoch is finished. It keeps doing so on further calls until reset().
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::get_batch(std::size_t batchSize)
{
  using Batch = std::tuple<torch::Tensor, torch::Tensor, std::string>;

  // Without any file there is nothing to load; indexing [0] would be UB.
  if (exampleLocations.empty())
    return c10::optional<Batch>();

  if (!loadedTensorIndex.has_value())
  {
    loadedTensorIndex = 0;
    nextIndexToGive = 0;
    torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
  }
  else if (loadedTensorIndex.value() >= exampleLocations.size())
  {
    // Epoch already finished: do not re-serve the last tensor or read past the end.
    return c10::optional<Batch>();
  }

  // Move on to the next file while the current one is consumed. The loop (rather
  // than a single step) also skips files that contain no rows.
  while (static_cast<std::int64_t>(nextIndexToGive) >= loadedTensor.size(0))
  {
    nextIndexToGive = 0;
    loadedTensorIndex = loadedTensorIndex.value() + 1;

    if (loadedTensorIndex.value() >= exampleLocations.size())
      return c10::optional<Batch>();

    torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
  }

  const std::int64_t start = static_cast<std::int64_t>(nextIndexToGive);
  const std::int64_t remaining = loadedTensor.size(0) - start; // > 0 after the loop above
  // Take a full batch if enough rows remain, otherwise everything that is left.
  const std::int64_t count = batchSize < static_cast<std::size_t>(remaining)
                             ? static_cast<std::int64_t>(batchSize)
                             : remaining;
  const std::int64_t lastColumn = loadedTensor.size(1) - 1;

  const auto rows = loadedTensor.narrow(0, start, count);

  Batch batch;
  std::get<0>(batch) = rows.narrow(1, 0, lastColumn);
  std::get<1>(batch) = rows.narrow(1, lastColumn, 1);
  std::get<2>(batch) = std::get<3>(exampleLocations[loadedTensorIndex.value()]);

  nextIndexToGive = start + count;

  return batch;
}

/// @brief Start a new epoch: shuffle the file order and forget the loaded file.
///
/// The next get_batch() call will load the first file of the new order.
void ConfigDataset::reset()
{
  // std::random_shuffle was deprecated in C++14 and removed in C++17. The engine
  // is seeded from std::rand() so that shuffling still follows std::srand
  // seeding, as random_shuffle typically did.
  std::mt19937 generator(static_cast<std::mt19937::result_type>(std::rand()));
  std::shuffle(exampleLocations.begin(), exampleLocations.end(), generator);

  loadedTensorIndex.reset();
  nextIndexToGive = 0;
}

/// @brief Serialization hook; restores nothing from the archive.
void ConfigDataset::load(torch::serialize::InputArchive & /*archive*/)
{
  // No dataset state is read back here.
}

/// @brief Serialization hook; writes nothing to the archive.
void ConfigDataset::save(torch::serialize::OutputArchive & /*archive*/) const
{
  // No dataset state is persisted here.
}