Commit dfd75ada authored by Franck Dary

Having separate wordEmbeddings for special values

parent b86af518
@@ -45,6 +45,7 @@ class Dict
private :
void addPrefixValues(std::string prefix);
void readFromFile(const char * filename);
void insert(const std::string & element);
void reset();
@@ -5,14 +5,7 @@ Dict::Dict(State state)
{
locked = false;
setState(state);
insert(unknownValueStr);
insert(nullValueStr);
insert(oobValueStr);
insert(noChildValueStr);
insert(emptyValueStr);
insert(numberValueStr);
insert(urlValueStr);
insert(separatorValueStr);
addPrefixValues("");
}
Dict::Dict(const char * filename, State state)
@@ -22,6 +15,17 @@ Dict::Dict(const char * filename, State state)
locked = false;
}
void Dict::addPrefixValues(std::string prefix)
{
for (auto & element : {unknownValueStr, nullValueStr, oobValueStr, noChildValueStr, emptyValueStr, numberValueStr, urlValueStr, separatorValueStr})
{
std::string prefixed = prefix.empty() ? element : fmt::format("{}({})", prefix, element);
if (!elementsToIndexes.count(prefixed))
insert(prefixed);
}
}
void Dict::lock()
{
locked = true;
@@ -64,6 +68,11 @@ void Dict::readFromFile(const char * filename)
if (!readEntry(file, &entryIndex, &nbOccsEntry, entryString, encoding))
util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));
std::string prefix = "";
auto splited = util::split(entryString, '(');
if (splited.size() > 1)
prefix = splited[0];
prefixes.insert(prefix);
if (elementsToIndexes.count(entryString))
util::myThrow(fmt::format("entry '{}' is already in dict", entryString));
if (indexesToElements.count(entryIndex))
@@ -101,7 +110,6 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & prefix)
if (state == State::Open)
elementsMutex.lock();
prefixes.insert(prefix);
int index = _getIndexOrInsert(element, prefix);
if (state == State::Open)
@@ -112,6 +120,11 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & prefix)
int Dict::_getIndexOrInsert(const std::string & element, const std::string & prefix)
{
if (!prefixes.count(prefix))
{
prefixes.insert(prefix);
addPrefixValues(prefix);
}
if (element.empty())
return _getIndexOrInsert(emptyValueStr, prefix);
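
The Dict changes above keep one copy of every special value per prefix: addPrefixValues stores the key "prefix(value)" (or the bare value for the empty prefix), readFromFile recovers the prefix by splitting an entry on its first '(', and _getIndexOrInsert lazily registers the special values the first time a new prefix is seen. Below is a minimal standalone sketch of that key convention; the prefix "FORM" and the literal "__unknown__" are purely illustrative placeholders, not the project's actual strings.

#include <fmt/core.h>
#include <cassert>
#include <string>

int main()
{
  // Hypothetical prefix and special-value literal, for illustration only.
  std::string prefix = "FORM";
  std::string element = "__unknown__";

  // Key formation, as in Dict::addPrefixValues: "prefix(value)",
  // or the bare value when the prefix is empty.
  std::string prefixed = prefix.empty() ? element : fmt::format("{}({})", prefix, element);
  assert(prefixed == "FORM(__unknown__)");

  // Key parsing, as in Dict::readFromFile: everything before the first '('
  // is the prefix; an entry without '(' belongs to the empty prefix.
  auto open = prefixed.find('(');
  std::string recovered = (open == std::string::npos) ? "" : prefixed.substr(0, open);
  assert(recovered == prefix);
  return 0;
}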
@@ -187,7 +187,7 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
void ContextModuleImpl::registerEmbeddings(bool loadPretrained)
{
if (!wordEmbeddings)
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, getDict().getSpecialIndexes()));
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
{
@@ -234,7 +234,7 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
void ContextualModuleImpl::registerEmbeddings(bool loadPretrained)
{
if (!wordEmbeddings)
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, getDict().getSpecialIndexes()));
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
@@ -164,7 +164,7 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config
void FocusedColumnModuleImpl::registerEmbeddings(bool loadPretrained)
{
if (!wordEmbeddings)
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, getDict().getSpecialIndexes()));
auto pathes = util::split(w2vFiles.string(), ' ');
for (auto & p : pathes)
{
@@ -9,8 +9,10 @@ float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim, std::set<std::size_t> specialIndexes)
{
for (auto elem : specialIndexes)
{
if (elem >= specialIndexes.size())
util::error("Special indexes are not contiguous from zero.");
}
if (maxNorm == std::numeric_limits<float>::max())
{
normalEmbeddings = register_module("normalEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq)));
@@ -57,6 +59,7 @@ torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
specialIndexes = torch::ones(specialRes.sizes(),torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
specialIndexes.index_put_({mask}, 0);
normalIndexes.index_put_({~mask}, 0);
return normalIndexes*normalRes + specialIndexes*specialRes;
}
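
The constructor check above works because a std::set of k distinct indexes that are all smaller than k can only be {0, ..., k-1}; with the special values packed at the bottom of the dictionary, the forward pass can route any index below k to a dedicated embedding table and everything else to the regular one, then merge the two lookups with a mask as in the hunk above. The following is a self-contained sketch of that routing under those assumptions, using libtorch; SplitEmbeddings and its members are hypothetical names, not classes from this repository.

#include <torch/torch.h>

// Hypothetical module: indexes < nbSpecial use a dedicated embedding table,
// all other indexes use the regular table, and the two lookups are merged
// with a mask, mirroring the logic of WordEmbeddingsImpl::forward above.
struct SplitEmbeddingsImpl : torch::nn::Module
{
  torch::nn::Embedding special{nullptr}, normal{nullptr};
  long nbSpecial;

  SplitEmbeddingsImpl(long vocab, long dim, long nbSpecial) : nbSpecial(nbSpecial)
  {
    special = register_module("special", torch::nn::Embedding(nbSpecial, dim));
    normal = register_module("normal", torch::nn::Embedding(vocab, dim));
  }

  torch::Tensor forward(torch::Tensor input) // input: Long tensor of dictionary indexes
  {
    auto mask = input < nbSpecial; // true where the index denotes a special value
    // Replace ids that are out of range for each table with 0; those lookups are discarded below.
    auto specialRes = special(torch::where(mask, input, torch::zeros_like(input)));
    auto normalRes = normal(torch::where(mask, torch::zeros_like(input), input));
    auto m = mask.unsqueeze(-1).to(specialRes.scalar_type());
    return m * specialRes + (1 - m) * normalRes; // keep the right lookup per position
  }
};
TORCH_MODULE(SplitEmbeddings);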