Select Git revision
-
Franck Dary authoredFranck Dary authored
Dict.cpp 9.14 KiB
#include "Dict.hpp"
#include <algorithm>
#include <cstdio>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "util.hpp"
/// @brief Build an empty dictionary in the given state, pre-filled with the special values.
/// @param state Initial Open/Closed state of the dictionary.
Dict::Dict(State state)
{
  // Start from the same baseline as the file-based constructor (which goes
  // through readFromFile() -> reset()): empty maps, occurrence counting off.
  reset();
  setState(state);

  // Special values are inserted first and in this fixed order; since insert()
  // assigns index = current size, they always get the same low indexes.
  for (const auto & special : {unknownValueStr, nullValueStr, oobValueStr, noChildValueStr,
                               emptyValueStr, numberValueStr, urlValueStr, separatorValueStr})
    insert(special);
}
/// @brief Build a dictionary from a file previously written by save().
/// @param filename Path of the dictionary file.
/// @param state State applied once the file has been loaded.
Dict::Dict(const char * filename, State state)
{
  // readFromFile() starts with reset(), which forces the state back to Closed,
  // so the requested state can only be applied after loading.
  this->readFromFile(filename);
  this->setState(state);
}
void Dict::readFromFile(const char * filename)
{
reset();
std::FILE * file = std::fopen(filename, "r");
if (!file)
util::myThrow(fmt::format("could not open file '{}'", filename));
char buffer[1048];
if (std::fscanf(file, "Encoding : %1047s\n", buffer) != 1)
util::myThrow(fmt::format("file '{}' bad format", filename));
Encoding encoding{Encoding::Ascii};
if (std::string(buffer) == "Ascii")
encoding = Encoding::Ascii;
else if (std::string(buffer) == "Binary")
encoding = Encoding::Binary;
else
util::myThrow(fmt::format("file '{}' bad format", filename));
int nbEntries;
if (std::fscanf(file, "Nb entries : %d\n", &nbEntries) != 1)
util::myThrow(fmt::format("file '{}' bad format", filename));
elementsToIndexes.reserve(nbEntries);
indexesToElements.reserve(nbEntries);
int entryIndex;
int nbOccsEntry;
char entryString[maxEntrySize+1];
for (int i = 0; i < nbEntries; i++)
{
if (!readEntry(file, &entryIndex, &nbOccsEntry, entryString, encoding))
util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));
if (elementsToIndexes.count(entryString))
util::myThrow(fmt::format("entry '{}' is already in dict", entryString));
if (indexesToElements.count(entryIndex))
util::myThrow(fmt::format("index '{}' is already in dict", entryIndex));
elementsToIndexes[entryString] = entryIndex;
indexesToElements[entryIndex] = entryString;
while ((int)nbOccs.size() <= entryIndex)
nbOccs.emplace_back(0);
nbOccs[entryIndex] = nbOccsEntry;
}
std::fclose(file);
}
/// @brief Add a new element, giving it the next index (current number of elements).
/// @param element Element to add; at most maxEntrySize characters.
/// @throws via util::myThrow if the element is too long, already present, or if
///         its would-be index is already used. Nothing is modified when it throws.
void Dict::insert(const std::string & element)
{
  if (element.size() > maxEntrySize)
    util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", element.size(), maxEntrySize));
  if (elementsToIndexes.count(element))
    util::myThrow(fmt::format("element '{}' already in dict", element));

  // Validate the new index before touching either map, so a failure cannot
  // leave the element registered in elementsToIndexes only.
  const auto newIndex = elementsToIndexes.size();
  if (indexesToElements.count(newIndex))
    util::myThrow(fmt::format("index '{}' already in dict", newIndex));

  elementsToIndexes.emplace(element, newIndex);
  indexesToElements.emplace(newIndex, element);

  // Keep one (zero-initialised) occurrence counter per index.
  if (nbOccs.size() < elementsToIndexes.size())
    nbOccs.resize(elementsToIndexes.size(), 0);
}
int Dict::getIndexOrInsert(const std::string & element, const std::string & prefix)
{
if (element.empty())
return getIndexOrInsert(emptyValueStr, prefix);
if (util::printedLength(element) == 1 and util::isSeparator(util::utf8char(element)))
{
return getIndexOrInsert(separatorValueStr, prefix);
}
if (util::isNumber(element))
return getIndexOrInsert(numberValueStr, prefix);
if (util::isUrl(element))
return getIndexOrInsert(urlValueStr, prefix);
auto prefixed = prefix.empty() ? element : fmt::format("{}({})", prefix, element);
const auto & found = elementsToIndexes.find(prefixed);
if (found == elementsToIndexes.end())
{
if (state == State::Open)
{
insert(prefixed);
if (isCountingOccs)
nbOccs[elementsToIndexes.at(prefixed)]++;
return elementsToIndexes.at(prefixed);
}
prefixed = prefix.empty() ? util::lower(element) : fmt::format("{}({})", prefix, util::lower(element));
const auto & found2 = elementsToIndexes.find(prefixed);
if (found2 != elementsToIndexes.end())
{
if (isCountingOccs)
nbOccs[found2->second]++;
return found2->second;
}
prefixed = prefix.empty() ? unknownValueStr : fmt::format("{}({})", prefix, unknownValueStr);
const auto & found3 = elementsToIndexes.find(prefixed);
if (found3 != elementsToIndexes.end())
{
if (isCountingOccs)
nbOccs[found3->second]++;
return found3->second;
}
return elementsToIndexes[unknownValueStr];
}
if (isCountingOccs)
nbOccs[found->second]++;
return found->second;
}
/// @brief Switch the dictionary to `state` (Open allows getIndexOrInsert() to insert).
void Dict::setState(State state)
{
  // The parameter shadows the member, so the member is named through its class.
  Dict::state = state;
}
/// @brief Current Open/Closed state, as last set by setState() or reset().
Dict::State Dict::getState() const
{
  return this->state;
}
/// @brief Write the dictionary to `path` in the layout readFromFile() expects.
///
/// Entries are written by ascending index, so saving the same dict twice
/// produces identical files.
/// @param path Destination file (overwritten).
/// @param encoding Record encoding (Ascii text lines or Binary records).
/// @throws via util::myThrow if the file cannot be opened, written or closed.
void Dict::save(std::filesystem::path path, Encoding encoding) const
{
  std::FILE * destination = std::fopen(path.c_str(), "w");
  if (!destination)
    util::myThrow(fmt::format("could not write file '{}'", path.string()));

  fprintf(destination, "Encoding : %s\n", encoding == Encoding::Ascii ? "Ascii" : "Binary");
  // %zu is the conversion for std::size_t; %lu is only right where size_t is unsigned long.
  fprintf(destination, "Nb entries : %zu\n", elementsToIndexes.size());

  // Hash-map iteration order is unspecified; sort by index for a reproducible dump.
  std::vector<std::pair<int, const std::string *>> entries;
  entries.reserve(elementsToIndexes.size());
  for (auto & it : elementsToIndexes)
    entries.emplace_back(it.second, &it.first);
  std::sort(entries.begin(), entries.end(),
            [](const auto & a, const auto & b) { return a.first < b.first; });
  for (auto & entry : entries)
    printEntry(destination, entry.first, *entry.second, encoding);

  // fprintf/fwrite failures (e.g. disk full) only surface through the stream
  // error flag and the result of fclose().
  const bool writeFailed = std::ferror(destination) != 0;
  if (std::fclose(destination) != 0 or writeFailed)
    util::myThrow(fmt::format("could not write file '{}'", path.string()));
}
/// @brief Read one record, in the format printEntry() writes, from `file`.
/// @param file Open stream positioned at the start of a record.
/// @param index Receives the record's index.
/// @param nbOccsEntry Receives the record's occurrence count.
/// @param entry Receives the NUL-terminated element; must have room for
///        maxEntrySize characters plus the terminating NUL.
/// @param encoding Record encoding.
/// @return true if a complete record was read.
bool Dict::readEntry(std::FILE * file, int * index, int * nbOccsEntry, char * entry, Encoding encoding)
{
  if (encoding == Encoding::Ascii)
  {
    // Line layout: "index<TAB>count<TAB>entry\n". The separator before the entry is
    // matched as exactly one tab (%*1[\t]); a plain "\t" directive would skip any
    // amount of whitespace and strip leading spaces/tabs belonging to the entry.
    static std::string readFormat = "%d\t%d%*1[\t]%"+std::to_string(maxEntrySize)+"[^\n]\n";
    return fscanf(file, readFormat.c_str(), index, nbOccsEntry, entry) == 3;
  }

  // Binary layout: raw index, raw count, then the NUL-terminated entry bytes.
  if (std::fread(index, sizeof *index, 1, file) != 1)
    return false;
  if (std::fread(nbOccsEntry, sizeof *nbOccsEntry, 1, file) != 1)
    return false;
  // An entry of exactly maxEntrySize characters has its NUL at offset
  // maxEntrySize, so up to maxEntrySize+1 bytes must be accepted.
  for (std::size_t i = 0; i <= maxEntrySize; i++)
  {
    if (std::fread(entry+i, 1, 1, file) != 1)
      return false;
    if (!entry[i])
      return true;
  }
  return false;
}
/// @brief Write one record (index, occurrence count, element) to `file`.
/// @param file Destination stream.
/// @param index Index of the element; its count is taken from getNbOccs().
/// @param entry Element text.
/// @param encoding Ascii writes a tab-separated text line; otherwise a binary record.
void Dict::printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const
{
  auto entryNbOccs = getNbOccs(index);

  if (encoding != Encoding::Ascii)
  {
    // Binary record: raw index, raw count, then the entry including its NUL terminator.
    std::fwrite(&index, sizeof index, 1, file);
    std::fwrite(&entryNbOccs, sizeof entryNbOccs, 1, file);
    std::fwrite(entry.c_str(), 1, entry.size()+1, file);
    return;
  }

  fprintf(file, "%d\t%d\t%s\n", index, entryNbOccs, entry.c_str());
}
/// @brief Enable or disable occurrence counting. While enabled,
///        getIndexOrInsert() increments nbOccs for most indexes it returns.
void Dict::countOcc(bool isCountingOccs)
{
  // The parameter shadows the member, so the member is named through its class.
  Dict::isCountingOccs = isCountingOccs;
}
/// @brief Number of elements, i.e. entries in the element -> index map.
std::size_t Dict::size() const
{
  const std::size_t nbElements = elementsToIndexes.size();
  return nbElements;
}
/// @brief Occurrence count recorded for `index`; 0 when the index has no counter
///        (negative or beyond the counts table).
int Dict::getNbOccs(int index) const
{
  const bool inRange = 0 <= index and index < static_cast<int>(nbOccs.size());
  return inRange ? nbOccs[index] : 0;
}
void Dict::removeRareElements()
{
int minNbOcc = std::numeric_limits<int>::max();
for (int nbOcc : nbOccs)
if (nbOcc < minNbOcc)
minNbOcc = nbOcc;
std::unordered_map<std::string, int> newElementsToIndexes;
std::vector<int> newNbOccs;
for (auto & it : elementsToIndexes)
if (nbOccs[it.second] > minNbOcc)
{
newElementsToIndexes.emplace(it.first, newElementsToIndexes.size());
newNbOccs.emplace_back(nbOccs[it.second]);
}
elementsToIndexes = newElementsToIndexes;
nbOccs = newNbOccs;
}
/// @brief Insert into the dict every word listed in a word2vec text file.
///
/// The dict is temporarily Opened and its original state restored afterwards,
/// on success and on error. With a non-empty prefix, the prefixed form of every
/// special value is inserted first, and words are inserted as "prefix(word)".
/// @param path word2vec file; an empty path is a no-op.
/// @param prefix Optional namespace for inserted words.
/// @return true if at least one line after the first was read.
/// @throws via util::myThrow if the file is missing, unreadable, empty or has a
///         malformed line.
bool Dict::loadWord2Vec(std::filesystem::path path, std::string prefix)
{
  if (path.empty())
    return false;

  if (!std::filesystem::exists(path))
    util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));

  // Owning handle: closed on every exit path, including the exceptions below.
  std::unique_ptr<std::FILE, decltype(&std::fclose)> file(std::fopen(path.c_str(), "r"), &std::fclose);
  if (!file)
    util::myThrow(fmt::format("could not open file '{}'", path.string()));

  auto originalState = getState();
  setState(Dict::State::Open);

  char buffer[100000];
  bool firstLine = true;
  bool pretrained = false;

  try
  {
    if (!prefix.empty())
    {
      // Collect first: inserting while iterating elementsToIndexes would invalidate the iteration.
      std::vector<std::string> toAdd;
      for (auto & it : elementsToIndexes)
        if (isSpecialValue(it.first))
          toAdd.emplace_back(fmt::format("{}({})", prefix, it.first));
      for (auto & elem : toAdd)
        getIndexOrInsert(elem, "");
    }

    while (!std::feof(file.get()))
    {
      if (buffer != std::fgets(buffer, sizeof buffer, file.get()))
        break;
      // The first line is skipped (presumably the word2vec "count dimension" header).
      if (firstLine)
      {
        firstLine = false;
        continue;
      }
      pretrained = true;

      auto splited = util::split(util::strip(buffer), ' ');
      if (splited.size() < 2)
        util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));
      if (splited[0] == "<unk>")
        continue;

      // Every "◌" in the token is turned back into a space before lookup.
      auto toInsert = util::splitAsUtf8(splited[0]);
      toInsert.replace("◌", " ");
      auto dictIndex = getIndexOrInsert(fmt::format("{}", toInsert), prefix);
      if (dictIndex == getIndexOrInsert(Dict::unknownValueStr, prefix) or dictIndex == getIndexOrInsert(Dict::nullValueStr, prefix) or dictIndex == getIndexOrInsert(Dict::emptyValueStr, prefix))
        util::myThrow(fmt::format("w2v line '{}' gave unexpected special dict index", buffer));
    }
  } catch (std::exception & e)
  {
    // Give the dict back in the caller's state before propagating.
    setState(originalState);
    util::myThrow(fmt::format("caught '{}'", e.what()));
  }

  setState(originalState);
  if (firstLine)
    util::myThrow(fmt::format("file '{}' is empty", path.string()));

  return pretrained;
}
bool Dict::isSpecialValue(const std::string & value)
{
return value == unknownValueStr
|| value == nullValueStr
|| value == oobValueStr
|| value == noChildValueStr
|| value == emptyValueStr
|| value == separatorValueStr
|| value == numberValueStr
|| value == urlValueStr;
}
/// @brief Element stored at `index`, or an empty string if no element has that index.
///
/// Uses find() rather than operator[] so an unknown index does not plant an empty
/// entry in indexesToElements; such an entry would later make insert() fail with
/// "index already in dict" once the dict grows to that index.
std::string Dict::getElement(std::size_t index)
{
  auto found = indexesToElements.find(index);
  if (found == indexesToElements.end())
    return std::string();
  return found->second;
}
/// @brief Empty the dictionary: no elements, no counts, Closed state, counting disabled.
void Dict::reset()
{
  isCountingOccs = false;
  state = State::Closed;
  nbOccs.clear();
  indexesToElements.clear();
  elementsToIndexes.clear();
}