Skip to content
Snippets Groups Projects
Select Git revision
  • master default protected
  • loss
  • producer
3 results

Dict.cpp

Blame
  • Dict.cpp 9.14 KiB
    #include "Dict.hpp"
    #include "util.hpp"
    
    // Creates an empty dict in the given state, pre-populated with the placeholder entries.
    // Insertion order decides the indexes the placeholders receive, so it must stay fixed.
    Dict::Dict(State state)
    {
      setState(state);

      const std::string placeholders[] = {
        unknownValueStr, nullValueStr, oobValueStr, noChildValueStr,
        emptyValueStr, numberValueStr, urlValueStr, separatorValueStr,
      };
      for (const std::string & placeholder : placeholders)
        insert(placeholder);
    }
    
    // Loads a dict previously written by save(), then switches it to the requested state.
    Dict::Dict(const char * filename, State state)
    {
      // readFromFile() starts with reset(), which forces the state to Closed,
      // so the requested state has to be applied after loading.
      this->readFromFile(filename);
      this->setState(state);
    }
    
    void Dict::readFromFile(const char * filename)
    {
      reset();
    
      std::FILE * file = std::fopen(filename, "r");
    
      if (!file)
        util::myThrow(fmt::format("could not open file '{}'", filename));
    
      char buffer[1048];
      if (std::fscanf(file, "Encoding : %1047s\n", buffer) != 1)
        util::myThrow(fmt::format("file '{}' bad format", filename));
    
      Encoding encoding{Encoding::Ascii};
      if (std::string(buffer) == "Ascii")
        encoding = Encoding::Ascii;
      else if (std::string(buffer) == "Binary")
        encoding = Encoding::Binary;
      else
        util::myThrow(fmt::format("file '{}' bad format", filename));
    
      int nbEntries;
    
      if (std::fscanf(file, "Nb entries : %d\n", &nbEntries) != 1)
        util::myThrow(fmt::format("file '{}' bad format", filename));
    
      elementsToIndexes.reserve(nbEntries);
      indexesToElements.reserve(nbEntries);
    
      int entryIndex;
      int nbOccsEntry;
      char entryString[maxEntrySize+1];
      for (int i = 0; i < nbEntries; i++)
      {
        if (!readEntry(file, &entryIndex, &nbOccsEntry, entryString, encoding))
          util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));
    
        if (elementsToIndexes.count(entryString))
          util::myThrow(fmt::format("entry '{}' is already in dict", entryString));
        if (indexesToElements.count(entryIndex))
          util::myThrow(fmt::format("index '{}' is already in dict", entryIndex));
        elementsToIndexes[entryString] = entryIndex;
        indexesToElements[entryIndex] = entryString;
        while ((int)nbOccs.size() <= entryIndex)
          nbOccs.emplace_back(0);
        nbOccs[entryIndex] = nbOccsEntry;
      }
    
      std::fclose(file);
    }
    
    // Adds `element` under the next index (the current number of elements) with an occurrence count of 0.
    // Throws if the element is longer than maxEntrySize, is already present, or if its would-be index
    // is already taken. Every check runs before anything is modified, so a throw leaves the dict unchanged.
    void Dict::insert(const std::string & element)
    {
      if (element.size() > maxEntrySize)
        util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", element.size(), maxEntrySize));

      if (elementsToIndexes.count(element))
        util::myThrow(fmt::format("element '{}' already in dict", element));

      const auto newIndex = elementsToIndexes.size();

      // Indexes loaded from a file need not be contiguous, so the next index may already be used.
      // Check before touching elementsToIndexes so the two maps never get out of sync.
      if (indexesToElements.count(newIndex))
        util::myThrow(fmt::format("index '{}' already in dict", newIndex));

      elementsToIndexes.emplace(element, newIndex);
      indexesToElements.emplace(newIndex, element);
      while (nbOccs.size() < elementsToIndexes.size())
        nbOccs.emplace_back(0);
    }
    
    int Dict::getIndexOrInsert(const std::string & element, const std::string & prefix)
    {
      if (element.empty())
        return getIndexOrInsert(emptyValueStr, prefix);
    
      if (util::printedLength(element) == 1 and util::isSeparator(util::utf8char(element)))
      {
        return getIndexOrInsert(separatorValueStr, prefix);
      }
    
      if (util::isNumber(element))
        return getIndexOrInsert(numberValueStr, prefix);
    
      if (util::isUrl(element))
        return getIndexOrInsert(urlValueStr, prefix);
    
      auto prefixed = prefix.empty() ? element : fmt::format("{}({})", prefix, element);
      const auto & found = elementsToIndexes.find(prefixed);
    
      if (found == elementsToIndexes.end())
      {
        if (state == State::Open)
        {
          insert(prefixed);
          if (isCountingOccs)
            nbOccs[elementsToIndexes.at(prefixed)]++;
          return elementsToIndexes.at(prefixed);
        }
    
        prefixed = prefix.empty() ? util::lower(element) : fmt::format("{}({})", prefix, util::lower(element));
        const auto & found2 = elementsToIndexes.find(prefixed);
        if (found2 != elementsToIndexes.end())
        {
          if (isCountingOccs)
            nbOccs[found2->second]++;
          return found2->second;   
        }
    
        prefixed = prefix.empty() ? unknownValueStr : fmt::format("{}({})", prefix, unknownValueStr);
    
        const auto & found3 = elementsToIndexes.find(prefixed);
        if (found3 != elementsToIndexes.end())
        {
          if (isCountingOccs)
            nbOccs[found3->second]++;
          return found3->second;   
        }
    
        return elementsToIndexes[unknownValueStr];
      }
    
      if (isCountingOccs)
        nbOccs[found->second]++;
      return found->second;
    }
    
    // Switches the dict between its states (e.g. Open, where getIndexOrInsert() may add new entries).
    void Dict::setState(State state)
    {
      Dict::state = state;
    }
    
    // Current state of the dict, as last set by setState() or reset().
    Dict::State Dict::getState() const
    {
      return this->state;
    }
    
    // Writes the dict to `path` in the format read back by readFromFile():
    // an "Encoding" line, an "Nb entries" line, then one printEntry() record per element.
    // Throws (via util::myThrow) if the file cannot be opened or fails to be written out on close.
    void Dict::save(std::filesystem::path path, Encoding encoding) const
    {
      std::FILE * destination = std::fopen(path.c_str(), "w");
      if (!destination)
        util::myThrow(fmt::format("could not write file '{}'", path.string()));

      fprintf(destination, "Encoding : %s\n", encoding == Encoding::Ascii ? "Ascii" : "Binary");
      // %zu is the portable conversion for std::size_t; %lu is only correct where size_t is unsigned long.
      fprintf(destination, "Nb entries : %zu\n", elementsToIndexes.size());
      for (const auto & [element, index] : elementsToIndexes)
        printEntry(destination, index, element, encoding);

      // fclose flushes buffered output; a failure here means the file on disk is incomplete.
      if (std::fclose(destination) != 0)
        util::myThrow(fmt::format("could not write file '{}'", path.string()));
    }
    
    // Reads one entry written by printEntry() into *index, *nbOccsEntry and the NUL-terminated `entry`.
    // `entry` must hold at least maxEntrySize + 1 chars (readFromFile passes such a buffer).
    // Returns false on malformed input, premature end of file, or an entry longer than maxEntrySize.
    //   Ascii:  "<index>\t<nbOccs>\t<entry>\n". The entry is everything after the second tab up to
    //           the newline, including leading/trailing whitespace.
    //   Binary: raw int index, raw int count, then the entry bytes followed by a NUL.
    bool Dict::readEntry(std::FILE * file, int * index, int * nbOccsEntry, char * entry, Encoding encoding)
    {
      if (encoding == Encoding::Ascii)
      {
        if (std::fscanf(file, "%d\t%d", index, nbOccsEntry) != 2)
          return false;

        // Consume exactly one separator tab. A "\t" scanf directive would skip any run of whitespace,
        // eating leading spaces of the entry (or the newline of an all-space entry and the next line).
        if (std::fgetc(file) != '\t')
          return false;

        std::size_t length = 0;
        for (int c = std::fgetc(file); c != EOF and c != '\n'; c = std::fgetc(file))
        {
          if (length == maxEntrySize)
            return false; // too long: fail instead of splitting the entry across two reads
          entry[length++] = static_cast<char>(c);
        }
        entry[length] = '\0';
        return true;
      }

      if (std::fread(index, sizeof *index, 1, file) != 1)
        return false;
      if (std::fread(nbOccsEntry, sizeof *nbOccsEntry, 1, file) != 1)
        return false;
      // An entry of maxEntrySize characters occupies maxEntrySize + 1 bytes including its NUL,
      // so up to maxEntrySize + 1 bytes must be read.
      for (std::size_t i = 0; i <= maxEntrySize; i++)
      {
        if (std::fread(entry + i, 1, 1, file) != 1)
          return false;
        if (entry[i] == '\0')
          return true;
      }
      return false;
    }
    
    void Dict::printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const
    {
      auto entryNbOccs = getNbOccs(index);
    
      if (encoding == Encoding::Ascii)
      {
        static std::string printFormat = "%d\t%d\t%s\n";
        fprintf(file, printFormat.c_str(), index, entryNbOccs, entry.c_str());
      }
      else
      {
        std::fwrite(&index, sizeof index, 1, file);
        std::fwrite(&entryNbOccs, sizeof entryNbOccs, 1, file);
        std::fwrite(entry.c_str(), 1, entry.size()+1, file);
      }
    }
    
    // Turns occurrence counting in getIndexOrInsert() on or off.
    void Dict::countOcc(bool isCountingOccs)
    {
      Dict::isCountingOccs = isCountingOccs;
    }
    
    // Number of elements currently mapped to an index.
    std::size_t Dict::size() const
    {
      const std::size_t nbElements = elementsToIndexes.size();
      return nbElements;
    }
    
    // Occurrence count recorded for `index`; 0 for any index outside the tracked range.
    int Dict::getNbOccs(int index) const
    {
      const bool tracked = index >= 0 && index < static_cast<int>(nbOccs.size());
      return tracked ? nbOccs[index] : 0;
    }
    
    void Dict::removeRareElements()
    {
      int minNbOcc = std::numeric_limits<int>::max();
      for (int nbOcc : nbOccs)
        if (nbOcc < minNbOcc)
          minNbOcc = nbOcc;
    
      std::unordered_map<std::string, int> newElementsToIndexes;
      std::vector<int> newNbOccs;
    
      for (auto & it : elementsToIndexes)
        if (nbOccs[it.second] > minNbOcc)
        {
          newElementsToIndexes.emplace(it.first, newElementsToIndexes.size());
          newNbOccs.emplace_back(nbOccs[it.second]);
        }
    
      elementsToIndexes = newElementsToIndexes;
      nbOccs = newNbOccs;
    }
    
    bool Dict::loadWord2Vec(std::filesystem::path path, std::string prefix)
    {
       if (path.empty())
        return false;
    
      if (!std::filesystem::exists(path))
        util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));
    
      auto originalState = getState();
      setState(Dict::State::Open);
    
      std::FILE * file = std::fopen(path.c_str(), "r");
      char buffer[100000];
    
      bool firstLine = true;
      bool pretrained = false;
    
      try
      {
        if (!prefix.empty())
        {
          std::vector<std::string> toAdd;
          for (auto & it : elementsToIndexes)
            if (isSpecialValue(it.first))
              toAdd.emplace_back(fmt::format("{}({})", prefix, it.first));
          for (auto & elem : toAdd)
            getIndexOrInsert(elem, "");
        }
    
        while (!std::feof(file))
        {
          if (buffer != std::fgets(buffer, 100000, file))
            break;
    
          if (firstLine)
          {
            firstLine = false;
            continue;
          }
    
          pretrained = true;
          auto splited = util::split(util::strip(buffer), ' ');
    
          if (splited.size() < 2)
            util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));
    
          if (splited[0] == "<unk>")
            continue;
          auto toInsert = util::splitAsUtf8(splited[0]);
          toInsert.replace("◌", " ");
          auto dictIndex = getIndexOrInsert(fmt::format("{}", toInsert), prefix);
    
          if (dictIndex == getIndexOrInsert(Dict::unknownValueStr, prefix) or dictIndex == getIndexOrInsert(Dict::nullValueStr, prefix) or dictIndex == getIndexOrInsert(Dict::emptyValueStr, prefix))
            util::myThrow(fmt::format("w2v line '{}' gave unexpected special dict index", buffer));
        }
      } catch (std::exception & e)
      {
        util::myThrow(fmt::format("caught '{}'", e.what()));
      }
    
      std::fclose(file);
    
      if (firstLine)
        util::myThrow(fmt::format("file '{}' is empty", path.string()));
    
      setState(originalState);
    
      return pretrained;
    }
    
    // True when `value` is one of the placeholder values the dict reserves for itself.
    bool Dict::isSpecialValue(const std::string & value)
    {
      if (value == unknownValueStr or value == nullValueStr or value == oobValueStr or value == noChildValueStr)
        return true;
      return value == emptyValueStr or value == separatorValueStr or value == numberValueStr or value == urlValueStr;
    }
    
    // Returns the element stored at `index`, or an empty string when no element has that index.
    // The lookup no longer modifies the dict.
    std::string Dict::getElement(std::size_t index)
    {
      // find() instead of operator[]: operator[] would insert an empty entry for an unknown index,
      // which later makes insert() fail with "index already in dict".
      auto found = indexesToElements.find(index);
      if (found == indexesToElements.end())
        return std::string();
      return found->second;
    }
    
    // Empties the dict and returns it to a Closed state with occurrence counting disabled.
    void Dict::reset()
    {
      isCountingOccs = false;
      state = State::Closed;
      nbOccs.clear();
      indexesToElements.clear();
      elementsToIndexes.clear();
    }