#include "Dict.hpp"
#include "util.hpp"

#include <cstddef>
#include <cstdio>
#include <exception>
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace
{

/// Deleter that lets std::unique_ptr own a C stdio stream.
struct FileCloser
{
  void operator()(std::FILE * file) const noexcept
  {
    std::fclose(file);
  }
};

/// Owning handle on a std::FILE; the stream is closed when the handle goes
/// out of scope, including when an error path exits by exception.
using FilePtr = std::unique_ptr<std::FILE, FileCloser>;

} // namespace

/// Builds a dictionary in the given state, pre-filled with the special
/// entries (unknown, null, empty, number, url) in that order.
Dict::Dict(State state)
{
  setState(state);
  insert(unknownValueStr);
  insert(nullValueStr);
  insert(emptyValueStr);
  insert(numberValueStr);
  insert(urlValueStr);
}

/// Builds a dictionary from a file previously written by save(), then
/// switches it to the given state.
Dict::Dict(const char * filename, State state)
{
  readFromFile(filename);
  setState(state);
}

/// Loads entries from `filename`.
///
/// Expected layout: an "Encoding : Ascii|Binary" header line, a
/// "Nb entries : N" header line, then N entries in the announced encoding.
/// Any open failure or format problem is reported through util::myThrow.
void Dict::readFromFile(const char * filename)
{
  const FilePtr file{std::fopen(filename, "r")};
  if (!file)
    util::myThrow(fmt::format("could not open file '{}'", filename));

  char buffer[1048];
  if (std::fscanf(file.get(), "Encoding : %1047s\n", buffer) != 1)
    util::myThrow(fmt::format("file '{}' bad format", filename));

  Encoding encoding{Encoding::Ascii};
  const std::string encodingName{buffer};
  if (encodingName == "Ascii")
    encoding = Encoding::Ascii;
  else if (encodingName == "Binary")
    encoding = Encoding::Binary;
  else
    util::myThrow(fmt::format("file '{}' bad format", filename));

  // A negative count would be converted to a huge unsigned value by reserve().
  int nbEntries = 0;
  if (std::fscanf(file.get(), "Nb entries : %d\n", &nbEntries) != 1 || nbEntries < 0)
    util::myThrow(fmt::format("file '{}' bad format", filename));

  elementsToIndexes.reserve(nbEntries);

  int entryIndex = 0;
  int nbOccsEntry = 0;
  char entryString[maxEntrySize+1];
  for (int i = 0; i < nbEntries; i++)
  {
    // A negative index would be used below to address nbOccs out of bounds.
    if (!readEntry(file.get(), &entryIndex, &nbOccsEntry, entryString, encoding) || entryIndex < 0)
      util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));

    elementsToIndexes[entryString] = entryIndex;

    // Grow the occurrence table (zero-filled) so that entryIndex is addressable.
    const auto slot = static_cast<std::size_t>(entryIndex);
    if (nbOccs.size() <= slot)
      nbOccs.resize(slot + 1, 0);
    nbOccs[slot] = nbOccsEntry;
  }
}

/// Adds `element` with the next free index (the current number of entries)
/// if it is not already present, and makes sure the occurrence table has a
/// slot for every entry. Elements longer than maxEntrySize are rejected
/// through util::myThrow.
void Dict::insert(const std::string & element)
{
  if (element.size() > maxEntrySize)
    util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", element.size(), maxEntrySize));

  elementsToIndexes.emplace(element, elementsToIndexes.size());

  if (nbOccs.size() < elementsToIndexes.size())
    nbOccs.resize(elementsToIndexes.size(), 0);
}

/// Returns the index of `element`, after normalisation.
///
/// Empty strings, single separator characters, numbers and urls are mapped to
/// their dedicated special entries. For any other element:
///  - if it is known, its index is returned;
///  - else if the dictionary is Open, it is inserted and its index returned;
///  - else its lowercased form is tried, and finally the unknown entry is used.
/// When occurrence counting is enabled, the counter of the returned index is
/// incremented.
int Dict::getIndexOrInsert(const std::string & element)
{
  if (element.empty())
    return getIndexOrInsert(emptyValueStr);

  if (element.size() == 1 and util::isSeparator(util::utf8char(element)))
    return getIndexOrInsert(separatorValueStr);

  if (util::isNumber(element))
    return getIndexOrInsert(numberValueStr);

  if (util::isUrl(element))
    return getIndexOrInsert(urlValueStr);

  const auto found = elementsToIndexes.find(element);
  if (found != elementsToIndexes.end())
  {
    if (isCountingOccs)
      nbOccs[found->second]++;
    return found->second;
  }

  if (state == State::Open)
  {
    insert(element);
    const int insertedIndex = elementsToIndexes[element];
    if (isCountingOccs)
      nbOccs[insertedIndex]++;
    return insertedIndex;
  }

  // Closed dictionary: fall back on the lowercased form before giving up.
  const auto foundLowered = elementsToIndexes.find(util::lower(element));
  if (foundLowered != elementsToIndexes.end())
  {
    if (isCountingOccs)
      nbOccs[foundLowered->second]++;
    return foundLowered->second;
  }

  const int unknownIndex = elementsToIndexes[unknownValueStr];
  if (isCountingOccs)
    nbOccs[unknownIndex]++;
  return unknownIndex;
}

/// Sets the state consulted by getIndexOrInsert() for unseen elements.
void Dict::setState(State state)
{
  this->state = state;
}

/// Returns the current state.
Dict::State Dict::getState() const
{
  return state;
}

/// Writes the dictionary to `path` in the requested encoding, using the
/// layout read back by readFromFile(). An open failure is reported through
/// util::myThrow.
void Dict::save(std::filesystem::path path, Encoding encoding) const
{
  const FilePtr destination{std::fopen(path.c_str(), "w")};
  if (!destination)
    util::myThrow(fmt::format("could not write file '{}'", path.string()));

  std::fprintf(destination.get(), "Encoding : %s\n", encoding == Encoding::Ascii ? "Ascii" : "Binary");
  // size() is a std::size_t, whose matching conversion is %zu (not %lu).
  std::fprintf(destination.get(), "Nb entries : %zu\n", elementsToIndexes.size());

  for (const auto & [entry, index] : elementsToIndexes)
    printEntry(destination.get(), index, entry, encoding);
}

/// Reads one entry (index, occurrence count, string) from `file`.
///
/// `entry` must point to a buffer of at least maxEntrySize+1 chars.
/// Returns true when a complete entry was read, false otherwise.
bool Dict::readEntry(std::FILE * file, int * index, int * nbOccsEntry, char * entry, Encoding encoding)
{
  if (encoding == Encoding::Ascii)
  {
    // Built once: the field width caps the string at maxEntrySize characters.
    static const std::string readFormat = "%d\t%d\t%"+std::to_string(maxEntrySize)+"[^\n]\n";
    return std::fscanf(file, readFormat.c_str(), index, nbOccsEntry, entry) == 3;
  }

  if (std::fread(index, sizeof *index, 1, file) != 1)
    return false;
  if (std::fread(nbOccsEntry, sizeof *nbOccsEntry, 1, file) != 1)
    return false;

  // The string is stored with its terminating '\0'. An entry of exactly
  // maxEntrySize characters has its terminator at offset maxEntrySize, so the
  // loop has to cover maxEntrySize+1 bytes (the size of the caller's buffer).
  for (std::size_t i = 0; i <= maxEntrySize; i++)
  {
    if (std::fread(entry+i, 1, 1, file) != 1)
      return false;
    if (!entry[i])
      return true;
  }

  return false;
}

/// Writes one entry to `file`: index, occurrence count and string, either as
/// a tab-separated text line (Ascii) or as raw ints followed by the
/// '\0'-terminated string (Binary).
void Dict::printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const
{
  const auto entryNbOccs = getNbOccs(index);

  if (encoding == Encoding::Ascii)
  {
    std::fprintf(file, "%d\t%d\t%s\n", index, entryNbOccs, entry.c_str());
    return;
  }

  std::fwrite(&index, sizeof index, 1, file);
  std::fwrite(&entryNbOccs, sizeof entryNbOccs, 1, file);
  std::fwrite(entry.c_str(), 1, entry.size()+1, file);
}

/// Enables or disables occurrence counting in getIndexOrInsert().
void Dict::countOcc(bool isCountingOccs)
{
  this->isCountingOccs = isCountingOccs;
}

/// Returns the number of entries.
std::size_t Dict::size() const
{
  return elementsToIndexes.size();
}

/// Returns the occurrence count stored for `index`, or 0 when `index` is
/// outside the occurrence table.
int Dict::getNbOccs(int index) const
{
  if (index < 0 || index >= (int)nbOccs.size())
    return 0;

  return nbOccs[index];
}

/// Drops every entry whose occurrence count equals the smallest count found
/// in the occurrence table, and renumbers the remaining entries from 0.
void Dict::removeRareElements()
{
  int minNbOcc = std::numeric_limits<int>::max();
  for (const int nbOcc : nbOccs)
    if (nbOcc < minNbOcc)
      minNbOcc = nbOcc;

  std::unordered_map<std::string, int> keptElements;
  std::vector<int> keptNbOccs;

  for (const auto & [element, index] : elementsToIndexes)
    if (nbOccs[index] > minNbOcc)
    {
      keptElements.emplace(element, keptElements.size());
      keptNbOccs.emplace_back(nbOccs[index]);
    }

  elementsToIndexes = std::move(keptElements);
  nbOccs = std::move(keptNbOccs);
}

/// Inserts the vocabulary of a word2vec text file into the dictionary.
///
/// The first line of the file is skipped; the first space-separated column
/// of every following line is inserted with the dictionary temporarily Open.
/// Does nothing when `path` is empty. Problems (missing or unreadable file,
/// malformed line, word mapping to the unknown/null/empty special entries,
/// empty file) are reported through util::myThrow. The state the dictionary
/// had on entry is restored before returning or reporting an error.
void Dict::loadWord2Vec(std::filesystem::path & path)
{
  if (path.empty())
    return;

  if (!std::filesystem::exists(path))
    util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));

  const FilePtr file{std::fopen(path.c_str(), "r")};
  if (!file)
    util::myThrow(fmt::format("could not open file '{}'", path.string()));

  const auto originalState = getState();
  setState(Dict::State::Open);

  // NOTE(review): a line longer than this buffer is split by fgets and its
  // tail is parsed as if it were a separate line.
  char buffer[100000];
  bool firstLine = true;

  try
  {
    while (std::fgets(buffer, sizeof buffer, file.get()) != nullptr)
    {
      if (firstLine)
      {
        firstLine = false;
        continue;
      }

      auto splited = util::split(util::strip(buffer), ' ');

      if (splited.size() < 2)
        util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));

      auto dictIndex = getIndexOrInsert(splited[0]);

      if (dictIndex == getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getIndexOrInsert(Dict::nullValueStr) or dictIndex == getIndexOrInsert(Dict::emptyValueStr))
        util::myThrow(fmt::format("w2v line '{}' gave unexpected special dict index", buffer));
    }
  } catch (const std::exception & e)
  {
    setState(originalState);
    util::myThrow(fmt::format("caught '{}'", e.what()));
  }

  setState(originalState);

  if (firstLine)
    util::myThrow(fmt::format("file '{}' is empty", path.string()));
}