#include "Dict.hpp"

#include <cstdio>
#include <memory>

#include "util.hpp"

/// Builds a dictionary in the given state and seeds it with the special
/// values; each one is handed to insert() in the order listed below.
Dict::Dict(State state)
{
  setState(state);

  for (const auto & specialValue : {unknownValueStr, nullValueStr, emptyValueStr, numberValueStr, urlValueStr})
    insert(specialValue);
}

/// Builds a dictionary from the content of a file (see readFromFile), then
/// puts it in the requested state.
Dict::Dict(const char * filename, State state)
{
  this->readFromFile(filename);
  this->setState(state);
}

void Dict::readFromFile(const char * filename)
{
  std::FILE * file = std::fopen(filename, "r");

  if (!file)
    util::myThrow(fmt::format("could not open file '{}'", filename));

  char buffer[1048];
  if (std::fscanf(file, "Encoding : %1047s\n", buffer) != 1)
    util::myThrow(fmt::format("file '{}' bad format", filename));

  Encoding encoding{Encoding::Ascii};
  if (std::string(buffer) == "Ascii")
    encoding = Encoding::Ascii;
  else if (std::string(buffer) == "Binary")
    encoding = Encoding::Binary;
  else
    util::myThrow(fmt::format("file '{}' bad format", filename));

  int nbEntries;

  if (std::fscanf(file, "Nb entries : %d\n", &nbEntries) != 1)
    util::myThrow(fmt::format("file '{}' bad format", filename));

  elementsToIndexes.reserve(nbEntries);

  int entryIndex;
  int nbOccsEntry;
  char entryString[maxEntrySize+1];
  for (int i = 0; i < nbEntries; i++)
  {
    if (!readEntry(file, &entryIndex, &nbOccsEntry, entryString, encoding))
      util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));

    elementsToIndexes[entryString] = entryIndex;
    while ((int)nbOccs.size() <= entryIndex)
      nbOccs.emplace_back(0);
    nbOccs[entryIndex] = nbOccsEntry;
  }

  std::fclose(file);
}

/// Adds an element to the dictionary, giving it the current number of
/// elements as index. An element that is already present keeps its index.
/// Elements longer than maxEntrySize are reported through util::myThrow.
void Dict::insert(const std::string & element)
{
  const auto elementSize = element.size();
  if (elementSize > maxEntrySize)
    util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", elementSize, maxEntrySize));

  // emplace does nothing when the key already exists.
  const auto nextIndex = elementsToIndexes.size();
  elementsToIndexes.emplace(element, nextIndex);

  // Keep at least one (zero-initialised) counter per known element.
  if (nbOccs.size() < elementsToIndexes.size())
    nbOccs.resize(elementsToIndexes.size(), 0);
}

int Dict::getIndexOrInsert(const std::string & element)
{
  if (element.empty())
    return getIndexOrInsert(emptyValueStr);

  if (element.size() == 1 and util::isSeparator(util::utf8char(element)))
    return getIndexOrInsert(separatorValueStr);

  if (util::isNumber(element))
    return getIndexOrInsert(numberValueStr);

  if (util::isUrl(element))
    return getIndexOrInsert(urlValueStr);

  const auto & found = elementsToIndexes.find(element);

  if (found == elementsToIndexes.end())
  {
    if (state == State::Open)
    {
      insert(element);
      if (isCountingOccs)
        nbOccs[elementsToIndexes[element]]++;
      return elementsToIndexes[element];
    }

    const auto & found2 = elementsToIndexes.find(util::lower(element));
    if (found2 != elementsToIndexes.end())
    {
      if (isCountingOccs)
        nbOccs[found2->second]++;
      return found2->second;   
    }

    if (isCountingOccs)
      nbOccs[elementsToIndexes[unknownValueStr]]++;
    return elementsToIndexes[unknownValueStr];
  }

  if (isCountingOccs)
    nbOccs[found->second]++;
  return found->second;
}

/// Changes the state of the dictionary. getIndexOrInsert only adds unseen
/// elements while the state is State::Open.
void Dict::setState(State state)
{
  Dict::state = state;
}

/// Returns the current state of the dictionary.
Dict::State Dict::getState() const
{
  return this->state;
}

/// Writes the dictionary to a file that readFromFile() can load back.
///
/// The file starts with the lines "Encoding : Ascii|Binary" and
/// "Nb entries : N", followed by one entry per element (see printEntry), in
/// the iteration order of the underlying map.
///
/// @param path     Destination file; it is created or truncated.
/// @param encoding How each entry is serialised.
///
/// A file that cannot be opened, written or closed properly is reported
/// through util::myThrow.
void Dict::save(std::filesystem::path path, Encoding encoding) const
{
  std::FILE * destination = std::fopen(path.c_str(), "w");
  if (!destination)
    util::myThrow(fmt::format("could not write file '{}'", path.string()));

  std::fprintf(destination, "Encoding : %s\n", encoding == Encoding::Ascii ? "Ascii" : "Binary");
  // %zu is the conversion matching std::size_t on every platform.
  std::fprintf(destination, "Nb entries : %zu\n", elementsToIndexes.size());
  for (const auto & [element, index] : elementsToIndexes)
    printEntry(destination, index, element, encoding);

  // Buffered write errors may only surface when the stream is flushed by
  // fclose, so both the error flag and the close result are checked.
  const bool writeFailed = std::ferror(destination) != 0;
  const bool closeFailed = std::fclose(destination) != 0;
  if (writeFailed or closeFailed)
    util::myThrow(fmt::format("could not write file '{}'", path.string()));
}

/// Reads one dictionary entry from an open stream.
///
/// Ascii entries are lines of the form "index<TAB>nbOccs<TAB>text"; Binary
/// entries are the raw bytes of the index, the raw bytes of the count, then
/// the text terminated by a zero byte.
///
/// @param file        Stream positioned at the start of an entry.
/// @param index       Receives the entry index.
/// @param nbOccsEntry Receives the entry occurrence count.
/// @param entry       Receives the zero-terminated text; must have room for
///                    maxEntrySize characters plus the terminator.
/// @param encoding    Serialisation used by the stream.
/// @return true when a complete entry was read, false otherwise.
bool Dict::readEntry(std::FILE * file, int * index, int * nbOccsEntry, char * entry, Encoding encoding)
{
  if (encoding == Encoding::Ascii)
  {
    // Built once: the field width limits the text to maxEntrySize characters.
    static const std::string readFormat = "%d\t%d\t%"+std::to_string(maxEntrySize)+"[^\n]\n";
    return std::fscanf(file, readFormat.c_str(), index, nbOccsEntry, entry) == 3;
  }

  if (std::fread(index, sizeof *index, 1, file) != 1)
    return false;
  if (std::fread(nbOccsEntry, sizeof *nbOccsEntry, 1, file) != 1)
    return false;

  // Up to maxEntrySize characters PLUS the terminator are read: the zero byte
  // of a maximum-length entry sits at offset maxEntrySize, so the bound has to
  // be inclusive or such entries could be written but never read back.
  for (unsigned int i = 0; i <= maxEntrySize; i++)
  {
    if (std::fread(entry+i, 1, 1, file) != 1)
      return false;
    if (!entry[i])
      return true;
  }

  // No terminator within the allowed size.
  return false;
}

void Dict::printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const
{
  auto entryNbOccs = getNbOccs(index);

  if (encoding == Encoding::Ascii)
  {
    static std::string printFormat = "%d\t%d\t%s\n";
    fprintf(file, printFormat.c_str(), index, entryNbOccs, entry.c_str());
  }
  else
  {
    std::fwrite(&index, sizeof index, 1, file);
    std::fwrite(&entryNbOccs, sizeof entryNbOccs, 1, file);
    std::fwrite(entry.c_str(), 1, entry.size()+1, file);
  }
}

/// Enables or disables occurrence counting: while enabled, getIndexOrInsert
/// increments the counter of every index it returns.
void Dict::countOcc(bool isCountingOccs)
{
  Dict::isCountingOccs = isCountingOccs;
}

std::size_t Dict::size() const
{
  return elementsToIndexes.size();
}

/// Returns the occurrence count recorded for an index, or 0 when the index
/// has no counter (negative, or past the end of nbOccs).
int Dict::getNbOccs(int index) const
{
  const bool hasCounter = index >= 0 and index < static_cast<int>(nbOccs.size());
  return hasCounter ? nbOccs[index] : 0;
}

void Dict::removeRareElements()
{
  int minNbOcc = std::numeric_limits<int>::max();
  for (int nbOcc : nbOccs)
    if (nbOcc < minNbOcc)
      minNbOcc = nbOcc;

  std::unordered_map<std::string, int> newElementsToIndexes;
  std::vector<int> newNbOccs;

  for (auto & it : elementsToIndexes)
    if (nbOccs[it.second] > minNbOcc)
    {
      newElementsToIndexes.emplace(it.first, newElementsToIndexes.size());
      newNbOccs.emplace_back(nbOccs[it.second]);
    }

  elementsToIndexes = newElementsToIndexes;
  nbOccs = newNbOccs;
}

/// Adds to the dictionary the words listed in a pretrained word2vec text file.
///
/// The first line of the file is skipped (presumably the word2vec header —
/// confirm against the expected file format). For every other line, the first
/// space-separated column is passed to getIndexOrInsert while the dictionary
/// is temporarily Open; the remaining columns are not used here. A line with
/// fewer than 2 columns, or a word resolving to the unknown/null/empty special
/// values, is an error.
///
/// @param path File to read; an empty path makes the call a no-op.
///
/// Errors (missing, unreadable, empty or malformed file) are reported through
/// util::myThrow. The original state is restored before any error raised
/// after the file was opened is reported.
void Dict::loadWord2Vec(std::filesystem::path & path)
{
  if (path.empty())
    return;

  if (!std::filesystem::exists(path))
    util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));

  // Owning handle: the stream is closed on every way out of this function.
  // It is opened before touching the state, so a failure here changes nothing.
  std::unique_ptr<std::FILE, int (*)(std::FILE *)> file(std::fopen(path.c_str(), "r"), &std::fclose);
  if (!file)
    util::myThrow(fmt::format("could not open file '{}'", path.string()));

  const auto originalState = getState();
  setState(Dict::State::Open);

  // NOTE: a line longer than the buffer is seen by fgets as several lines.
  char buffer[100000];

  bool firstLine = true;

  try
  {
    while (std::fgets(buffer, static_cast<int>(sizeof buffer), file.get()))
    {
      if (firstLine)
      {
        firstLine = false;
        continue;
      }

      auto splited = util::split(util::strip(buffer), ' ');

      if (splited.size() < 2)
        util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));

      auto dictIndex = getIndexOrInsert(splited[0]);

      if (dictIndex == getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getIndexOrInsert(Dict::nullValueStr) or dictIndex == getIndexOrInsert(Dict::emptyValueStr))
        util::myThrow(fmt::format("w2v line '{}' gave unexpected special dict index", buffer));
    }
  } catch (const std::exception & e)
  {
    // Do not leave the dictionary Open behind a failed load.
    setState(originalState);
    util::myThrow(fmt::format("caught '{}'", e.what()));
  }

  setState(originalState);

  if (firstLine)
    util::myThrow(fmt::format("file '{}' is empty", path.string()));
}