Commit dfd75ada authored by Franck Dary

Having separate wordEmbeddings for special values

parent b86af518
@@ -45,6 +45,7 @@ class Dict
   private :
+  void addPrefixValues(std::string prefix);
   void readFromFile(const char * filename);
   void insert(const std::string & element);
   void reset();
...
@@ -5,14 +5,7 @@ Dict::Dict(State state)
 {
   locked = false;
   setState(state);
-  insert(unknownValueStr);
-  insert(nullValueStr);
-  insert(oobValueStr);
-  insert(noChildValueStr);
-  insert(emptyValueStr);
-  insert(numberValueStr);
-  insert(urlValueStr);
-  insert(separatorValueStr);
+  addPrefixValues("");
 }
 
 Dict::Dict(const char * filename, State state)
@@ -22,6 +15,17 @@ Dict::Dict(const char * filename, State state)
   locked = false;
 }
 
+void Dict::addPrefixValues(std::string prefix)
+{
+  for (auto & element : {unknownValueStr, nullValueStr, oobValueStr, noChildValueStr, emptyValueStr, numberValueStr, urlValueStr, separatorValueStr})
+  {
+    std::string prefixed = prefix.empty() ? element : fmt::format("{}({})", prefix, element);
+    if (!elementsToIndexes.count(prefixed))
+      insert(prefixed);
+  }
+}
+
 void Dict::lock()
 {
   locked = true;
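
The new Dict::addPrefixValues composes one dictionary key per special value and per prefix, following the "prefix(value)" pattern visible in the hunk above. Below is a minimal standalone sketch of that naming scheme; the special-value strings and the "FORM" prefix are placeholders, not the actual contents of unknownValueStr and the other constants.

// Standalone sketch of the key scheme used by Dict::addPrefixValues.
// The special-value strings here are illustrative assumptions.
#include <fmt/format.h>
#include <string>
#include <vector>

std::vector<std::string> prefixedSpecialValues(const std::string & prefix)
{
  std::vector<std::string> result;
  for (auto & element : {std::string("<unk>"), std::string("<null>"), std::string("<oob>")})
    result.push_back(prefix.empty() ? element : fmt::format("{}({})", prefix, element));
  return result;
}

// prefixedSpecialValues("")     -> {"<unk>", "<null>", "<oob>"}
// prefixedSpecialValues("FORM") -> {"FORM(<unk>)", "FORM(<null>)", "FORM(<oob>)"}
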
@@ -64,6 +68,11 @@ void Dict::readFromFile(const char * filename)
     if (!readEntry(file, &entryIndex, &nbOccsEntry, entryString, encoding))
       util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));
 
+    std::string prefix = "";
+    auto splited = util::split(entryString, '(');
+    if (splited.size() > 1)
+      prefix = splited[0];
+    prefixes.insert(prefix);
     if (elementsToIndexes.count(entryString))
       util::myThrow(fmt::format("entry '{}' is already in dict", entryString));
     if (indexesToElements.count(entryIndex))
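
On reload, readFromFile now recovers the prefix of each stored entry by splitting on '(' and re-registers it in prefixes. A minimal sketch of that recovery step as a free function, assuming entries are serialized as "prefix(value)" and that unprefixed values contain no '(':

// Hedged sketch of the prefix-recovery step in Dict::readFromFile.
#include <string>

std::string recoverPrefix(const std::string & entry)
{
  auto pos = entry.find('(');
  // No '(' means the entry was stored without a prefix.
  return pos == std::string::npos ? std::string() : entry.substr(0, pos);
}

// recoverPrefix("FORM(<unk>)") == "FORM"
// recoverPrefix("<unk>")       == ""
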
@@ -101,7 +110,6 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & prefix)
   if (state == State::Open)
     elementsMutex.lock();
 
-  prefixes.insert(prefix);
   int index = _getIndexOrInsert(element, prefix);
 
   if (state == State::Open)
@@ -112,6 +120,11 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & prefix)
 int Dict::_getIndexOrInsert(const std::string & element, const std::string & prefix)
 {
+  if (!prefixes.count(prefix))
+  {
+    prefixes.insert(prefix);
+    addPrefixValues(prefix);
+  }
   if (element.empty())
     return _getIndexOrInsert(emptyValueStr, prefix);
...
@@ -187,7 +187,7 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 void ContextModuleImpl::registerEmbeddings(bool loadPretrained)
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
...
@@ -234,7 +234,7 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 void ContextualModuleImpl::registerEmbeddings(bool loadPretrained)
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
...
@@ -164,7 +164,7 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config
 void FocusedColumnModuleImpl::registerEmbeddings(bool loadPretrained)
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
...
@@ -9,8 +9,10 @@ float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
 WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim, std::set<std::size_t> specialIndexes)
 {
   for (auto elem : specialIndexes)
+  {
     if (elem >= specialIndexes.size())
       util::error("Special indexes are not contiguous from zero.");
+  }
   if (maxNorm == std::numeric_limits<float>::max())
   {
     normalEmbeddings = register_module("normalEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq)));
@@ -57,6 +59,7 @@ torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
   specialIndexes = torch::ones(specialRes.sizes(),torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
   specialIndexes.index_put_({mask}, 0);
   normalIndexes.index_put_({~mask}, 0);
+
   return normalIndexes*normalRes + specialIndexes*specialRes;
 }
...
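
Taken together with the constructor check above (special indexes must be contiguous from zero), WordEmbeddingsImpl routes each input index to one of two tables: presumably indexes below the number of special values hit a dedicated special-value table, all others hit the normal table, and the two masked lookups are summed. The sketch below reproduces that routing in isolation; it uses torch::where instead of the index_put_ calls shown above, and the names are illustrative rather than the module's actual fields.

// Standalone libtorch sketch of the masked two-table lookup performed by
// WordEmbeddingsImpl::forward. Names and the use of torch::where are
// simplifications, not the exact implementation.
#include <torch/torch.h>
#include <cstdint>

torch::Tensor lookupWithSpecialTable(torch::nn::Embedding normalEmbeddings,
                                     torch::nn::Embedding specialEmbeddings,
                                     torch::Tensor input, std::int64_t nbSpecial)
{
  // True where the index refers to one of the special values (0 .. nbSpecial-1).
  auto mask = input < nbSpecial;
  // Feed in-range indexes to each table; out-of-range positions are clamped to 0
  // and their rows are zeroed out by the mask before the two results are summed.
  auto normalRes  = normalEmbeddings->forward(torch::where(mask, torch::zeros_like(input), input));
  auto specialRes = specialEmbeddings->forward(torch::where(mask, input, torch::zeros_like(input)));
  auto maskF = mask.unsqueeze(-1).to(normalRes.scalar_type());
  return (1 - maskF) * normalRes + maskF * specialRes;
}
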