Commit 57db2a2e authored by Franck Dary
Browse files

Changed the way prefixes are handled in dicts

parent 4f4cd7c3
......@@ -26,6 +26,7 @@ class Dict
private :
std::unordered_map<std::string, int> elementsToIndexes;
std::unordered_map<int, std::string> indexesToElements;
std::vector<int> nbOccs;
State state;
bool isCountingOccs{false};
......@@ -43,7 +44,8 @@ class Dict
public :
void countOcc(bool isCountingOccs);
int getIndexOrInsert(const std::string & element);
int getIndexOrInsert(const std::string & element, const std::string & prefix);
std::string getElement(std::size_t index);
void setState(State state);
State getState() const;
void save(std::filesystem::path path, Encoding encoding) const;
......@@ -52,7 +54,8 @@ class Dict
std::size_t size() const;
int getNbOccs(int index) const;
void removeRareElements();
void loadWord2Vec(std::filesystem::path path);
void loadWord2Vec(std::filesystem::path path, std::string prefix);
bool isSpecialValue(const std::string & value);
};
#endif
......@@ -42,6 +42,7 @@ void Dict::readFromFile(const char * filename)
util::myThrow(fmt::format("file '{}' bad format", filename));
elementsToIndexes.reserve(nbEntries);
indexesToElements.reserve(nbEntries);
int entryIndex;
int nbOccsEntry;
......@@ -52,6 +53,7 @@ void Dict::readFromFile(const char * filename)
util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));
elementsToIndexes[entryString] = entryIndex;
indexesToElements[entryIndex] = entryString;
while ((int)nbOccs.size() <= entryIndex)
nbOccs.emplace_back(0);
nbOccs[entryIndex] = nbOccsEntry;
......@@ -66,37 +68,40 @@ void Dict::insert(const std::string & element)
util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", element.size(), maxEntrySize));
elementsToIndexes.emplace(element, elementsToIndexes.size());
indexesToElements.emplace(elementsToIndexes.size()-1, element);
while (nbOccs.size() < elementsToIndexes.size())
nbOccs.emplace_back(0);
}
int Dict::getIndexOrInsert(const std::string & element)
int Dict::getIndexOrInsert(const std::string & element, const std::string & prefix)
{
if (element.empty())
return getIndexOrInsert(emptyValueStr);
return getIndexOrInsert(emptyValueStr, prefix);
if (element.size() == 1 and util::isSeparator(util::utf8char(element)))
return getIndexOrInsert(separatorValueStr);
return getIndexOrInsert(separatorValueStr, prefix);
if (util::isNumber(element))
return getIndexOrInsert(numberValueStr);
return getIndexOrInsert(numberValueStr, prefix);
if (util::isUrl(element))
return getIndexOrInsert(urlValueStr);
return getIndexOrInsert(urlValueStr, prefix);
const auto & found = elementsToIndexes.find(element);
auto prefixed = prefix.empty() ? element : fmt::format("{}({})", prefix, element);
const auto & found = elementsToIndexes.find(prefixed);
if (found == elementsToIndexes.end())
{
if (state == State::Open)
{
insert(element);
insert(prefixed);
if (isCountingOccs)
nbOccs[elementsToIndexes[element]]++;
return elementsToIndexes[element];
nbOccs[elementsToIndexes[prefixed]]++;
return elementsToIndexes[prefixed];
}
const auto & found2 = elementsToIndexes.find(util::lower(element));
prefixed = prefix.empty() ? util::lower(element) : fmt::format("{}({})", prefix, util::lower(element));
const auto & found2 = elementsToIndexes.find(prefixed);
if (found2 != elementsToIndexes.end())
{
if (isCountingOccs)
......@@ -104,9 +109,10 @@ int Dict::getIndexOrInsert(const std::string & element)
return found2->second;
}
prefixed = prefix.empty() ? unknownValueStr : fmt::format("{}({})", prefix, unknownValueStr);
if (isCountingOccs)
nbOccs[elementsToIndexes[unknownValueStr]]++;
return elementsToIndexes[unknownValueStr];
nbOccs[elementsToIndexes[prefixed]]++;
return elementsToIndexes[prefixed];
}
if (isCountingOccs)
......@@ -217,7 +223,7 @@ void Dict::removeRareElements()
nbOccs = newNbOccs;
}
void Dict::loadWord2Vec(std::filesystem::path path)
void Dict::loadWord2Vec(std::filesystem::path path, std::string prefix)
{
if (path.empty())
return;
......@@ -235,6 +241,16 @@ void Dict::loadWord2Vec(std::filesystem::path path)
try
{
if (!prefix.empty())
{
std::vector<std::string> toAdd;
for (auto & it : elementsToIndexes)
if (isSpecialValue(it.first))
toAdd.emplace_back(fmt::format("{}({})", prefix, it.first));
for (auto & elem : toAdd)
getIndexOrInsert(elem, "");
}
while (!std::feof(file))
{
if (buffer != std::fgets(buffer, 100000, file))
......@@ -251,9 +267,13 @@ void Dict::loadWord2Vec(std::filesystem::path path)
if (splited.size() < 2)
util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));
auto dictIndex = getIndexOrInsert(splited[0]);
if (splited[0] == "<unk>")
continue;
auto toInsert = util::splitAsUtf8(splited[0]);
toInsert.replace("◌", " ");
auto dictIndex = getIndexOrInsert(fmt::format("{}", toInsert), prefix);
if (dictIndex == getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getIndexOrInsert(Dict::nullValueStr) or dictIndex == getIndexOrInsert(Dict::emptyValueStr))
if (dictIndex == getIndexOrInsert(Dict::unknownValueStr, prefix) or dictIndex == getIndexOrInsert(Dict::nullValueStr, prefix) or dictIndex == getIndexOrInsert(Dict::emptyValueStr, prefix))
util::myThrow(fmt::format("w2v line '{}' gave unexpected special dict index", buffer));
}
} catch (std::exception & e)
......@@ -269,3 +289,18 @@ void Dict::loadWord2Vec(std::filesystem::path path)
setState(originalState);
}
// Tells whether `value` is exactly equal to one of the dictionary's special
// value strings: unknownValueStr, nullValueStr, emptyValueStr,
// separatorValueStr, numberValueStr or urlValueStr.
// The comparison is plain string equality, tested in the order listed above.
bool Dict::isSpecialValue(const std::string & value)
{
  if (value == unknownValueStr or value == nullValueStr or value == emptyValueStr)
    return true;

  return value == separatorValueStr or value == numberValueStr or value == urlValueStr;
}
// Returns the element stored under `index` in indexesToElements, or an empty
// string when nothing is registered under that index.
//
// The lookup goes through find() instead of operator[]: operator[] would
// default-insert an empty entry for every unknown index, silently growing
// indexesToElements with keys that have no counterpart in elementsToIndexes.
std::string Dict::getElement(std::size_t index)
{
  // indexesToElements is keyed by int, so the index is narrowed explicitly.
  const auto found = indexesToElements.find(static_cast<int>(index));
  if (found == indexesToElements.end())
    return std::string();

  return found->second;
}
......@@ -22,7 +22,7 @@ class ContextualModuleImpl : public Submodule
int inSize;
int outSize;
std::filesystem::path path;
std::filesystem::path w2vFile;
std::filesystem::path w2vFiles;
public :
......
......@@ -16,7 +16,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde
public :
void setFirstInputIndex(std::size_t firstInputIndex);
void loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path);
void loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path, std::string prefix);
virtual std::size_t getOutputSize() = 0;
virtual std::size_t getInputSize() = 0;
virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
......
......@@ -54,9 +54,14 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
{
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
getDict().loadWord2Vec(this->path / p);
getDict().setState(Dict::State::Closed);
dictSetPretrained(true);
{
auto splited = util::split(p, ',');
if (splited.size() != 2)
util::myThrow("expected 'prefix,pretrained.w2v'");
getDict().loadWord2Vec(this->path / splited[1], splited[0]);
getDict().setState(Dict::State::Closed);
dictSetPretrained(true);
}
}
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
......@@ -117,7 +122,7 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
if (index == -1)
{
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", col, Dict::nullValueStr)));
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
}
else
{
......@@ -126,23 +131,17 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
{
std::string value;
if (config.isCommentPredicted(index))
value = "ID(comment)";
value = "comment";
else if (config.isMultiwordPredicted(index))
value = "ID(multiword)";
value = "multiword";
else if (config.isTokenPredicted(index))
value = "ID(token)";
dictIndex = dict.getIndexOrInsert(value);
}
else if (col == Config::EOSColName)
{
dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index)));
value = "token";
dictIndex = dict.getIndexOrInsert(value, col);
}
else
{
std::string featureValue = functions[colIndex](config.getAsFeature(col, index));
if (w2vFiles.empty())
featureValue = fmt::format("{}({})", col, featureValue);
dictIndex = dict.getIndexOrInsert(featureValue);
dictIndex = dict.getIndexOrInsert(featureValue, col);
}
for (auto & contextElement : context)
......@@ -165,6 +164,9 @@ void ContextModuleImpl::registerEmbeddings()
wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
loadPretrainedW2vEmbeddings(wordEmbeddings, path / p);
{
auto splited = util::split(p, ',');
loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
}
}
......@@ -53,13 +53,20 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string &
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
w2vFile = sm.str(7);
w2vFiles = sm.str(7);
if (!w2vFile.empty())
if (!w2vFiles.empty())
{
getDict().loadWord2Vec(this->path / w2vFile);
getDict().setState(Dict::State::Closed);
dictSetPretrained(true);
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
{
auto splited = util::split(p, ',');
if (splited.size() != 2)
util::myThrow("expected 'prefix,file.w2v'");
getDict().loadWord2Vec(this->path / splited[1], splited[0]);
getDict().setState(Dict::State::Closed);
dictSetPretrained(true);
}
}
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
......@@ -127,17 +134,13 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
if (index == -1)
{
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", col, Dict::nullValueStr)));
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
}
else if (index == -2)
{
//TODO maybe change this to a unique value like Dict::noneValueStr
for (auto & contextElement : context)
{
auto currentState = dict.getState();
dict.setState(Dict::State::Open);
contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", col, "_NONE_")));
dict.setState(currentState);
}
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
}
else
{
......@@ -146,23 +149,17 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
{
std::string value;
if (config.isCommentPredicted(index))
value = "ID(comment)";
value = "comment";
else if (config.isMultiwordPredicted(index))
value = "ID(multiword)";
value = "multiword";
else if (config.isTokenPredicted(index))
value = "ID(token)";
dictIndex = dict.getIndexOrInsert(value);
}
else if (col == Config::EOSColName)
{
dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index)));
value = "token";
dictIndex = dict.getIndexOrInsert(value, col);
}
else
{
std::string featureValue = config.getAsFeature(col, index);
if (w2vFile.empty())
featureValue = fmt::format("{}({})", col, featureValue);
dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue));
dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue), col);
}
for (auto & contextElement : context)
......@@ -214,6 +211,12 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
void ContextualModuleImpl::registerEmbeddings()
{
wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile.empty() ? "" : path / w2vFile);
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
{
auto splited = util::split(p, ',');
loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
}
}
......@@ -117,9 +117,9 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon
for (int i = 0; i < maxElemPerDepth[depth]; i++)
for (auto & col : columns)
if (i < (int)newChilds.size() and config.has(col, std::stoi(newChilds[i]), 0))
contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i]))));
contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i])), col));
else
contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
}
}
}
......
......@@ -86,6 +86,8 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context,
else
toIndexes.emplace_back(-1);
std::string prefix = "DISTANCE";
for (auto & contextElement : context)
{
for (auto from : fromIndexes)
......@@ -93,16 +95,16 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context,
{
if (from == -1 or to == -1)
{
contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
continue;
}
long dist = std::abs(config.getRelativeDistance(from, to));
if (dist <= threshold)
contextElement.emplace_back(dict.getIndexOrInsert(fmt::format("distance({})", dist)));
contextElement.emplace_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, dist), ""));
else
contextElement.emplace_back(dict.getIndexOrInsert(Dict::unknownValueStr));
contextElement.emplace_back(dict.getIndexOrInsert(Dict::unknownValueStr, prefix));
}
}
}
......
......@@ -84,7 +84,7 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
if (index == -1)
{
for (int i = 0; i < maxNbElements; i++)
contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, column));
continue;
}
......@@ -93,6 +93,7 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
{
auto asUtf8 = util::splitAsUtf8(func(config.getAsFeature(column, index).get()));
//TODO don't use nullValueStr here
for (int i = 0; i < maxNbElements; i++)
if (i < (int)asUtf8.size())
elements.emplace_back(fmt::format("{}", asUtf8[i]));
......@@ -105,23 +106,23 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
for (int i = 0; i < maxNbElements; i++)
if (i < (int)splited.size())
elements.emplace_back(fmt::format("FEATS({})", splited[i]));
elements.emplace_back(splited[i]);
else
elements.emplace_back(Dict::nullValueStr);
}
else if (column == "ID")
{
if (config.isTokenPredicted(index))
elements.emplace_back("ID(TOKEN)");
elements.emplace_back("TOKEN");
else if (config.isMultiwordPredicted(index))
elements.emplace_back("ID(MULTIWORD)");
elements.emplace_back("MULTIWORD");
else if (config.isEmptyNodePredicted(index))
elements.emplace_back("ID(EMPTYNODE)");
elements.emplace_back("EMPTYNODE");
}
else if (column == "EOS")
{
bool isEOS = func(config.getAsFeature(Config::EOSColName, index)) == Config::EOSSymbol1;
elements.emplace_back(fmt::format("EOS({})", isEOS));
elements.emplace_back(fmt::format("{}", isEOS));
}
else
{
......@@ -132,7 +133,7 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
util::myThrow(fmt::format("elements.size ({}) != maxNbElements ({})", elements.size(), maxNbElements));
for (auto & element : elements)
contextElement.emplace_back(dict.getIndexOrInsert(element));
contextElement.emplace_back(dict.getIndexOrInsert(element, column));
}
}
}
......
......@@ -57,12 +57,14 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
{
auto & dict = getDict();
std::string prefix = "HISTORY";
for (auto & contextElement : context)
for (int i = 0; i < maxNbElements; i++)
if (config.hasHistory(i))
contextElement.emplace_back(dict.getIndexOrInsert(config.getHistory(i)));
contextElement.emplace_back(dict.getIndexOrInsert(config.getHistory(i), prefix));
else
contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
}
void HistoryModuleImpl::registerEmbeddings()
......
......@@ -57,20 +57,22 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context,
if (leftWindow < 0 or rightWindow < 0)
return;
std::string prefix = "LETTER";
auto & dict = getDict();
for (auto & contextElement : context)
{
for (int i = 0; i < leftWindow; i++)
if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i))
contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i))));
contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, config.getLetter(config.getCharacterIndex()-leftWindow+i)), ""));
else
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
for (int i = 0; i <= rightWindow; i++)
if (config.hasCharacter(config.getCharacterIndex()+i))
contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, config.getLetter(config.getCharacterIndex()+i)), ""));
else
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
}
}
......
......@@ -58,9 +58,9 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context
for (auto & contextElement : context)
for (int i = 0; i < maxNbTrans; i++)
if (i < (int)splitTransitions.size())
contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName(), ""));
else
contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, ""));
}
void SplitTransModuleImpl::registerEmbeddings()
......
......@@ -33,7 +33,7 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context,
{
auto & dict = getDict();
for (auto & contextElement : context)
contextElement.emplace_back(dict.getIndexOrInsert(config.getState()));
contextElement.emplace_back(dict.getIndexOrInsert(config.getState(), ""));
}
void StateNameModuleImpl::registerEmbeddings()
......
......@@ -5,7 +5,7 @@ void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
this->firstInputIndex = firstInputIndex;
}
void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path)
void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path, std::string prefix)
{
if (path.empty())
return;
......@@ -44,12 +44,14 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s
if (splited.size() < 2)
util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));
auto dictIndex = getDict().getIndexOrInsert(splited[0]);
std::string word;
if (splited[0] == "<unk>")
dictIndex = getDict().getIndexOrInsert(Dict::unknownValueStr);
word = Dict::unknownValueStr;
else
word = splited[0];
if (splited[0] != "<unk>" and splited[0] != Dict::unknownValueStr and (dictIndex == getDict().getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::nullValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::emptyValueStr)))
continue;
auto dictIndex = getDict().getIndexOrInsert(word, prefix);
if (embeddingsSize != splited.size()-1)
util::myThrow(fmt::format("in line \n{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'", buffer, embeddingsSize, ((int)splited.size())-1));
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment