diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index fa33fef26425e9765f4030b7f5da5b0321c8beed..d37ff6563bff8b95ac5258b18b83895b6ae7f484 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -45,6 +45,7 @@ class Dict
 
   private :
 
+  void addPrefixValues(std::string prefix);
   void readFromFile(const char * filename);
   void insert(const std::string & element);
   void reset();
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index 2a4f1180c90186d4292c0b18e8abb1442ebe5283..82b0c8d3e4e74f0567e9e9e57534be5f2157d251 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -5,14 +5,7 @@
 Dict::Dict(State state)
 {
   locked = false;
   setState(state);
-  insert(unknownValueStr);
-  insert(nullValueStr);
-  insert(oobValueStr);
-  insert(noChildValueStr);
-  insert(emptyValueStr);
-  insert(numberValueStr);
-  insert(urlValueStr);
-  insert(separatorValueStr);
+  addPrefixValues("");
 }
 
@@ -22,6 +15,17 @@ Dict::Dict(const char * filename, State state)
   locked = false;
 }
+void Dict::addPrefixValues(std::string prefix)
+{
+  for (auto & element : {unknownValueStr, nullValueStr, oobValueStr, noChildValueStr, emptyValueStr, numberValueStr, urlValueStr, separatorValueStr})
+  {
+    std::string prefixed = prefix.empty() ? element : fmt::format("{}({})", prefix, element);
+    if (!elementsToIndexes.count(prefixed))
+
+      insert(prefixed);
+  }
+}
+
 void Dict::lock()
 {
   locked = true;
 }
@@ -64,6 +68,11 @@ void Dict::readFromFile(const char * filename)
     if (!readEntry(file, &entryIndex, &nbOccsEntry, entryString, encoding))
       util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));
 
+    std::string prefix = "";
+    auto splited = util::split(entryString, '(');
+    if (splited.size() > 1)
+      prefix = splited[0];
+    prefixes.insert(prefix);
     if (elementsToIndexes.count(entryString))
       util::myThrow(fmt::format("entry '{}' is already in dict", entryString));
     if (indexesToElements.count(entryIndex))
@@ -101,7 +110,6 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & pref
   if (state == State::Open)
     elementsMutex.lock();
 
-  prefixes.insert(prefix);
   int index = _getIndexOrInsert(element, prefix);
 
   if (state == State::Open)
@@ -112,6 +120,11 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & pref
 
 int Dict::_getIndexOrInsert(const std::string & element, const std::string & prefix)
 {
+  if (!prefixes.count(prefix))
+  {
+    prefixes.insert(prefix);
+    addPrefixValues(prefix);
+  }
 
   if (element.empty())
     return _getIndexOrInsert(emptyValueStr, prefix);
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index 48f9a00755204874d9f044e0d6cbce359b2f92ed..ae92243b322a3d392fa27d6da985063f3ba7e1bd 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -187,7 +187,7 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 void ContextModuleImpl::registerEmbeddings(bool loadPretrained)
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index 564c95f5ce4e51ff502b64103726cdcaff166429..3c8125858c67ec553abd3d13f88633e90af3b574 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -234,7 +234,7 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 void ContextualModuleImpl::registerEmbeddings(bool loadPretrained)
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, getDict().getSpecialIndexes()));
 
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 23ebe6f71eda68eff2cd2ca4029e599aa047b80a..1a0a9d32fe4129b022c6c361171f4542570d1566 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -164,7 +164,7 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config
 void FocusedColumnModuleImpl::registerEmbeddings(bool loadPretrained)
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
diff --git a/torch_modules/src/WordEmbeddings.cpp b/torch_modules/src/WordEmbeddings.cpp
index 38a5e3bd738390755646d1431528aab0a14d75b0..045296fec76cfea88d31e12040bdc4e2ce865504 100644
--- a/torch_modules/src/WordEmbeddings.cpp
+++ b/torch_modules/src/WordEmbeddings.cpp
@@ -9,8 +9,10 @@ float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
 WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim, std::set<std::size_t> specialIndexes)
 {
   for (auto elem : specialIndexes)
+  {
     if (elem >= specialIndexes.size())
       util::error("Special indexes are not contiguous from zero.");
+  }
   if (maxNorm == std::numeric_limits<float>::max())
   {
     normalEmbeddings = register_module("normalEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq)));
@@ -57,6 +59,7 @@ torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
   specialIndexes = torch::ones(specialRes.sizes(),torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
   specialIndexes.index_put_({mask}, 0);
   normalIndexes.index_put_({~mask}, 0);
+
   return normalIndexes*normalRes + specialIndexes*specialRes;
 }
 
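Note on the naming scheme introduced above: addPrefixValues stores one copy of each special value per column prefix, encoded as prefix(value), and readFromFile recovers the prefix as the text before the first '('. Below is a minimal standalone sketch of that encoding, not part of the patch: the helper names prefixedKey and prefixOf and the sample strings "FORM" and "<unk>" are hypothetical; the patch itself builds the key with fmt::format("{}({})", prefix, element) and splits entries with util::split(entryString, '(').

#include <cassert>
#include <string>

// Encode a special value under a column prefix, mirroring Dict::addPrefixValues:
// the empty prefix keeps the bare value, so unprefixed entries stay valid.
std::string prefixedKey(const std::string & prefix, const std::string & element)
{
  return prefix.empty() ? element : prefix + "(" + element + ")";
}

// Recover the prefix of a stored entry, mirroring Dict::readFromFile:
// everything before the first '(', or the empty prefix when there is none.
std::string prefixOf(const std::string & entry)
{
  auto open = entry.find('(');
  return open == std::string::npos ? "" : entry.substr(0, open);
}

int main()
{
  // Sample prefix "FORM" and value "<unk>" are illustrative only; the real
  // special values live in identifiers such as unknownValueStr.
  assert(prefixedKey("", "<unk>") == "<unk>");
  assert(prefixedKey("FORM", "<unk>") == "FORM(<unk>)");
  assert(prefixOf("FORM(<unk>)") == "FORM");
  assert(prefixOf("<unk>") == "");
  return 0;
}

Because the empty prefix maps to the bare value, dictionaries written before this change still load: entries without a parenthesis simply register the empty prefix.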