From dfd75ada672f65f3717bd54b76330a3a011e1303 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 12 Nov 2021 20:56:32 +0100
Subject: [PATCH] Having separate wordEmbeddings for special values

---
 common/include/Dict.hpp                   |  1 +
 common/src/Dict.cpp                       | 31 ++++++++++++++++-------
 torch_modules/src/ContextModule.cpp       |  2 +-
 torch_modules/src/ContextualModule.cpp    |  2 +-
 torch_modules/src/FocusedColumnModule.cpp |  2 +-
 torch_modules/src/WordEmbeddings.cpp      |  3 +++
 6 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index fa33fef..d37ff65 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -45,6 +45,7 @@ class Dict
 
   private :
 
+  void addPrefixValues(std::string prefix);
   void readFromFile(const char * filename);
   void insert(const std::string & element);
   void reset();
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index 2a4f118..82b0c8d 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -5,14 +5,7 @@ Dict::Dict(State state)
 {
   locked = false;
   setState(state);
-  insert(unknownValueStr);
-  insert(nullValueStr);
-  insert(oobValueStr);
-  insert(noChildValueStr);
-  insert(emptyValueStr);
-  insert(numberValueStr);
-  insert(urlValueStr);
-  insert(separatorValueStr);
+  addPrefixValues("");
 }
 
 Dict::Dict(const char * filename, State state)
@@ -22,6 +15,17 @@ Dict::Dict(const char * filename, State state)
   locked = false;
 }
 
+void Dict::addPrefixValues(std::string prefix)
+{
+  for (auto & element : {unknownValueStr, nullValueStr, oobValueStr, noChildValueStr, emptyValueStr, numberValueStr, urlValueStr, separatorValueStr})
+  {
+    std::string prefixed = prefix.empty() ? element : fmt::format("{}({})", prefix, element);
+    if (!elementsToIndexes.count(prefixed))
+
+      insert(prefixed);
+  }
+}
+
 void Dict::lock()
 {
   locked = true;
@@ -64,6 +68,11 @@ void Dict::readFromFile(const char * filename)
     if (!readEntry(file, &entryIndex, &nbOccsEntry, entryString, encoding))
       util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));
 
+    std::string prefix = "";
+    auto splited = util::split(entryString, '(');
+    if (splited.size() > 1)
+      prefix = splited[0];
+    prefixes.insert(prefix);
     if (elementsToIndexes.count(entryString))
       util::myThrow(fmt::format("entry '{}' is already in dict", entryString));
     if (indexesToElements.count(entryIndex))
@@ -101,7 +110,6 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & pref
   if (state == State::Open)
     elementsMutex.lock();
 
-  prefixes.insert(prefix);
   int index = _getIndexOrInsert(element, prefix);
 
   if (state == State::Open)
@@ -112,6 +120,11 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & pref
 
 int Dict::_getIndexOrInsert(const std::string & element, const std::string & prefix)
 {
+  if (!prefixes.count(prefix))
+  {
+    prefixes.insert(prefix);
+    addPrefixValues(prefix);
+  }
   if (element.empty())
     return _getIndexOrInsert(emptyValueStr, prefix);
 
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index 48f9a00..ae92243 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -187,7 +187,7 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 void ContextModuleImpl::registerEmbeddings(bool loadPretrained)
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index 564c95f..3c81258 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -234,7 +234,7 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 void ContextualModuleImpl::registerEmbeddings(bool loadPretrained)
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, getDict().getSpecialIndexes()));
 
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 23ebe6f..1a0a9d3 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -164,7 +164,7 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config
 void FocusedColumnModuleImpl::registerEmbeddings(bool loadPretrained)
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
diff --git a/torch_modules/src/WordEmbeddings.cpp b/torch_modules/src/WordEmbeddings.cpp
index 38a5e3b..045296f 100644
--- a/torch_modules/src/WordEmbeddings.cpp
+++ b/torch_modules/src/WordEmbeddings.cpp
@@ -9,8 +9,10 @@ float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
 WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim, std::set<std::size_t> specialIndexes)
 {
   for (auto elem : specialIndexes)
+  {
     if (elem >= specialIndexes.size())
       util::error("Special indexes are not contiguous from zero.");
+  }
   if (maxNorm == std::numeric_limits<float>::max())
   {
     normalEmbeddings = register_module("normalEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq)));
@@ -57,6 +59,7 @@ torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
   specialIndexes = torch::ones(specialRes.sizes(),torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
   specialIndexes.index_put_({mask}, 0);
   normalIndexes.index_put_({~mask}, 0);
+
   return normalIndexes*normalRes + specialIndexes*specialRes;
 }
 
-- 
GitLab
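
Note on the embedding split (not part of the patch above): the Dict now registers the special values (unknown, null, oob, no-child, empty, number, url, separator) for every prefix it encounters, and WordEmbeddingsImpl checks that the special indexes are contiguous from zero so they can be looked up in a dedicated embedding table, separate from the table used for ordinary vocabulary entries. The following is a minimal, self-contained sketch of that mask-based dispatch, in the spirit of WordEmbeddingsImpl::forward; the names mixedEmbeddingForward and nbSpecialValues are hypothetical, and the code is an illustration rather than the project's actual API.

// Hypothetical sketch: special indexes (0..nbSpecialValues-1) are embedded by
// a dedicated table, all other indexes by the normal table.
#include <torch/torch.h>
#include <iostream>

torch::Tensor mixedEmbeddingForward(torch::nn::Embedding normalEmbeddings,
                                    torch::nn::Embedding specialEmbeddings,
                                    int64_t nbSpecialValues,
                                    torch::Tensor input)
{
  // True where the index refers to a special value (assumed contiguous from zero).
  auto mask = input < nbSpecialValues;

  // Indexes outside the special table's range are remapped to 0 before the
  // lookup; the corresponding rows are discarded by the masking below.
  auto specialInput = torch::where(mask, input, torch::zeros_like(input));
  auto normalRes = normalEmbeddings->forward(input);
  auto specialRes = specialEmbeddings->forward(specialInput);

  // Per-position 0/1 weights keep exactly one of the two lookups.
  auto specialWeight = mask.to(torch::kFloat).unsqueeze(-1);
  auto normalWeight = mask.logical_not().to(torch::kFloat).unsqueeze(-1);

  return normalWeight * normalRes + specialWeight * specialRes;
}

int main()
{
  // Toy setup: 10 dictionary entries, the first 3 being special values.
  torch::nn::Embedding normalEmbeddings(10, 4);
  torch::nn::Embedding specialEmbeddings(3, 4);
  auto input = torch::tensor({0, 1, 5, 9}, torch::kLong);
  auto output = mixedEmbeddingForward(normalEmbeddings, specialEmbeddings, 3, input);
  std::cout << output.sizes() << std::endl; // prints [4, 4]
}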