#include "Dict.hpp"

#include <cstdio>
#include <memory>
#include <string>

#include "util.hpp"

/// Build a dictionary that starts with the three special values.
///
/// Each special value is added through insert(), which gives it index
/// `elementsToIndexes.size()` at the time of insertion. On an empty dictionary
/// that is 0, 1 and 2 respectively, in the order below.
///
/// @param state initial state (e.g. whether unknown elements may be inserted).
Dict::Dict(State state)
{
  this->insert(unknownValueStr);
  this->insert(nullValueStr);
  this->insert(emptyValueStr);

  // insert() does not depend on the state, so it can be set last.
  this->setState(state);
}

/// Build a dictionary whose contents are loaded from a file written by save().
///
/// Unlike the State-only constructor, no special values are inserted here.
/// Whatever the file contains is the whole dictionary.
///
/// @param filename path of the dictionary file to load.
/// @param state    state to set once loading has finished.
Dict::Dict(const char * filename, State state)
{
  this->readFromFile(filename);
  this->setState(state);
}

void Dict::readFromFile(const char * filename)
{
  std::FILE * file = std::fopen(filename, "r");

  if (!file)
    util::myThrow(fmt::format("could not open file '{}'", filename));

  char buffer[1048];
  if (std::fscanf(file, "Encoding : %1047s\n", buffer) != 1)
    util::myThrow(fmt::format("file '{}' bad format", filename));

  Encoding encoding{Encoding::Ascii};
  if (std::string(buffer) == "Ascii")
    encoding = Encoding::Ascii;
  else if (std::string(buffer) == "Binary")
    encoding = Encoding::Binary;
  else
    util::myThrow(fmt::format("file '{}' bad format", filename));

  int nbEntries;

  if (std::fscanf(file, "Nb entries : %d\n", &nbEntries) != 1)
    util::myThrow(fmt::format("file '{}' bad format", filename));

  elementsToIndexes.reserve(nbEntries);

  int entryIndex;
  int nbOccsEntry;
  char entryString[maxEntrySize+1];
  for (int i = 0; i < nbEntries; i++)
  {
    if (!readEntry(file, &entryIndex, &nbOccsEntry, entryString, encoding))
      util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));

    elementsToIndexes[entryString] = entryIndex;
    while (nbOccs.size() <= entryIndex)
      nbOccs.emplace_back(0);
    nbOccs[entryIndex] = nbOccsEntry;
  }

  std::fclose(file);
}

/// Add `element` to the dictionary if it is not already present.
///
/// A new element gets index `elementsToIndexes.size()` (its size before the
/// insertion). An element that is already present keeps its existing index.
/// `nbOccs` is then padded with zero counters until it has at least one slot
/// per known element.
///
/// @param element string to add; must not be longer than maxEntrySize.
/// Raises through util::myThrow when `element` exceeds maxEntrySize.
void Dict::insert(const std::string & element)
{
  if (element.size() > maxEntrySize)
    util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", element.size(), maxEntrySize));

  const auto candidateIndex = elementsToIndexes.size();
  elementsToIndexes.emplace(element, candidateIndex);

  for (; elementsToIndexes.size() > nbOccs.size();)
    nbOccs.emplace_back(0);
}

/// Return the index of `element`, possibly inserting it.
///
/// - An empty string is looked up as emptyValueStr instead.
/// - A known element returns its index.
/// - An unknown element is inserted when the state is Open. Otherwise the
///   index of unknownValueStr is returned.
///
/// When occurrence counting is enabled (countOcc(true)), the counter of the
/// returned index is incremented.
///
/// @param element string to look up.
/// @return index associated with `element` (or with a special value).
int Dict::getIndexOrInsert(const std::string & element)
{
  if (element.empty())
    return getIndexOrInsert(emptyValueStr);

  int index = 0;
  const auto found = elementsToIndexes.find(element);

  if (found != elementsToIndexes.end())
  {
    index = found->second;
  }
  else if (state == State::Open)
  {
    insert(element);
    index = elementsToIndexes.find(element)->second;
  }
  else
  {
    auto unknown = elementsToIndexes.find(unknownValueStr);
    // A dictionary loaded from a file may lack unknownValueStr. The previous
    // operator[] lookup then default-inserted it with index 0, colliding with
    // the real index-0 entry. Insert it properly (fresh index) instead.
    if (unknown == elementsToIndexes.end())
    {
      insert(unknownValueStr);
      unknown = elementsToIndexes.find(unknownValueStr);
    }
    index = unknown->second;
  }

  if (isCountingOccs)
    nbOccs[index]++;
  return index;
}

/// Replace the current state of the dictionary.
///
/// @param state new state (read by getIndexOrInsert to decide whether unknown
///              elements may be inserted).
void Dict::setState(State state)
{
  Dict::state = state;
}

/// @return the current state of the dictionary.
Dict::State Dict::getState() const
{
  return this->state;
}

/// Write the dictionary to `destination` in the format read by readFromFile().
///
/// Two text header lines ("Encoding : ..." and "Nb entries : ...") are
/// followed by one entry per element, written by printEntry() in `encoding`.
/// Entries appear in the container's iteration order.
///
/// @param destination open stream to write to (not closed here).
/// @param encoding    Ascii or Binary entry encoding.
void Dict::save(std::FILE * destination, Encoding encoding) const
{
  std::fprintf(destination, "Encoding : %s\n", encoding == Encoding::Ascii ? "Ascii" : "Binary");
  // %zu is the conversion for std::size_t. %lu is only correct where size_t
  // happens to be unsigned long (not on LLP64 targets such as 64-bit Windows).
  std::fprintf(destination, "Nb entries : %zu\n", static_cast<std::size_t>(elementsToIndexes.size()));
  for (const auto & it : elementsToIndexes)
    printEntry(destination, it.second, it.first, encoding);
}

/// Decode one entry written by printEntry().
///
/// Ascii:  "<index>\t<nbOccs>\t<entry>\n" (entry up to maxEntrySize chars).
/// Binary: raw int index, raw int nbOccs, then the entry bytes including
///         their terminating NUL.
///
/// @param file        stream positioned at the start of an entry.
/// @param index       receives the entry index.
/// @param nbOccsEntry receives the entry occurrence count.
/// @param entry       receives the NUL-terminated entry string; must hold at
///                    least maxEntrySize+1 chars.
/// @param encoding    encoding the entry was written with.
/// @return true if a complete entry was decoded, false otherwise.
bool Dict::readEntry(std::FILE * file, int * index, int * nbOccsEntry, char * entry, Encoding encoding)
{
  if (encoding == Encoding::Ascii)
  {
    // "%*1[\t]" consumes exactly one tab separator. A plain "\t" directive
    // would skip any amount of whitespace, stripping leading spaces/tabs from
    // the entry (and, for an all-whitespace entry, reading the next line).
    static const std::string readFormat = "%d\t%d%*1[\t]%"+std::to_string(maxEntrySize)+"[^\n]\n";
    return std::fscanf(file, readFormat.c_str(), index, nbOccsEntry, entry) == 3;
  }

  if (std::fread(index, sizeof *index, 1, file) != 1)
    return false;
  if (std::fread(nbOccsEntry, sizeof *nbOccsEntry, 1, file) != 1)
    return false;

  // insert() accepts strings of exactly maxEntrySize chars, so up to
  // maxEntrySize+1 bytes (chars + NUL) must be accepted here.
  for (std::size_t i = 0; i <= static_cast<std::size_t>(maxEntrySize); i++)
  {
    if (std::fread(entry+i, 1, 1, file) != 1)
      return false;
    if (!entry[i])
      return true;
  }
  // No NUL within maxEntrySize+1 bytes: entry too long or corrupted.
  return false;
}

void Dict::printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const
{
  auto entryNbOccs = getNbOccs(index);

  if (encoding == Encoding::Ascii)
  {
    static std::string printFormat = "%d\t%d\t%s\n";
    fprintf(file, printFormat.c_str(), index, entryNbOccs, entry.c_str());
  }
  else
  {
    std::fwrite(&index, sizeof index, 1, file);
    std::fwrite(&entryNbOccs, sizeof entryNbOccs, 1, file);
    std::fwrite(entry.c_str(), 1, entry.size()+1, file);
  }
}

/// Enable or disable occurrence counting in getIndexOrInsert().
///
/// @param isCountingOccs true to increment counters on each lookup.
void Dict::countOcc(bool isCountingOccs)
{
  Dict::isCountingOccs = isCountingOccs;
}

std::size_t Dict::size() const
{
  return elementsToIndexes.size();
}

/// Occurrence count stored for `index`.
///
/// @param index entry index.
/// @return the stored count, or 0 when `index` has no counter slot
///         (negative or past the end of nbOccs).
int Dict::getNbOccs(int index) const
{
  const bool hasSlot = index >= 0 && index < static_cast<int>(nbOccs.size());
  return hasSlot ? nbOccs[index] : 0;
}