From 57db2a2e15f62c7e0e7b627313ce99fa0dcab4df Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 31 Jul 2020 17:24:12 +0200
Subject: [PATCH] Changed the way prefixes are handled in dicts

---
 common/include/Dict.hpp                    |  7 +-
 common/src/Dict.cpp                        | 65 ++++++++++++++-----
 torch_modules/include/ContextualModule.hpp |  2 +-
 torch_modules/include/Submodule.hpp        |  2 +-
 torch_modules/src/ContextModule.cpp        | 34 +++++-----
 torch_modules/src/ContextualModule.cpp     | 51 ++++++++-------
 .../src/DepthLayerTreeEmbeddingModule.cpp  |  4 +-
 torch_modules/src/DistanceModule.cpp       |  8 ++-
 torch_modules/src/FocusedColumnModule.cpp  | 15 +++--
 torch_modules/src/HistoryModule.cpp        |  6 +-
 torch_modules/src/RawInputModule.cpp       | 10 +--
 torch_modules/src/SplitTransModule.cpp     |  4 +-
 torch_modules/src/StateNameModule.cpp      |  2 +-
 torch_modules/src/Submodule.cpp            | 12 ++--
 14 files changed, 137 insertions(+), 85 deletions(-)

diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index efd5806..dda547b 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -26,6 +26,7 @@ class Dict
   private :
 
   std::unordered_map<std::string, int> elementsToIndexes;
+  std::unordered_map<int, std::string> indexesToElements;
   std::vector<int> nbOccs;
   State state;
   bool isCountingOccs{false};
@@ -43,7 +44,8 @@ class Dict
   public :
 
   void countOcc(bool isCountingOccs);
-  int getIndexOrInsert(const std::string & element);
+  int getIndexOrInsert(const std::string & element, const std::string & prefix);
+  std::string getElement(std::size_t index);
   void setState(State state);
   State getState() const;
   void save(std::filesystem::path path, Encoding encoding) const;
@@ -52,7 +54,8 @@ class Dict
   std::size_t size() const;
   int getNbOccs(int index) const;
   void removeRareElements();
-  void loadWord2Vec(std::filesystem::path path);
+  void loadWord2Vec(std::filesystem::path path, std::string prefix);
+  bool isSpecialValue(const std::string & value);
 };
 
 #endif
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index 49e678f..d4e7ba2 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -42,6 +42,7 @@ void Dict::readFromFile(const char * filename)
     util::myThrow(fmt::format("file '{}' bad format", filename));
 
   elementsToIndexes.reserve(nbEntries);
+  indexesToElements.reserve(nbEntries);
 
   int entryIndex;
   int nbOccsEntry;
@@ -52,6 +53,7 @@
       util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));
 
     elementsToIndexes[entryString] = entryIndex;
+    indexesToElements[entryIndex] = entryString;
     while ((int)nbOccs.size() <= entryIndex)
       nbOccs.emplace_back(0);
     nbOccs[entryIndex] = nbOccsEntry;
@@ -66,37 +68,40 @@ void Dict::insert(const std::string & element)
     util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", element.size(), maxEntrySize));
 
   elementsToIndexes.emplace(element, elementsToIndexes.size());
+  indexesToElements.emplace(elementsToIndexes.size()-1, element);
   while (nbOccs.size() < elementsToIndexes.size())
     nbOccs.emplace_back(0);
 }
 
 
-int Dict::getIndexOrInsert(const std::string & element)
+int Dict::getIndexOrInsert(const std::string & element, const std::string & prefix)
 {
   if (element.empty())
-    return getIndexOrInsert(emptyValueStr);
+    return getIndexOrInsert(emptyValueStr, prefix);
 
   if (element.size() == 1 and util::isSeparator(util::utf8char(element)))
-    return getIndexOrInsert(separatorValueStr);
+    return getIndexOrInsert(separatorValueStr, prefix);
 
   if (util::isNumber(element))
-    return getIndexOrInsert(numberValueStr);
+    return getIndexOrInsert(numberValueStr, prefix);
 
   if (util::isUrl(element))
-    return getIndexOrInsert(urlValueStr);
+    return getIndexOrInsert(urlValueStr, prefix);
 
-  const auto & found = elementsToIndexes.find(element);
+  auto prefixed = prefix.empty() ? element : fmt::format("{}({})", prefix, element);
+  const auto & found = elementsToIndexes.find(prefixed);
   if (found == elementsToIndexes.end())
   {
     if (state == State::Open)
    {
-      insert(element);
+      insert(prefixed);
       if (isCountingOccs)
-        nbOccs[elementsToIndexes[element]]++;
-      return elementsToIndexes[element];
+        nbOccs[elementsToIndexes[prefixed]]++;
+      return elementsToIndexes[prefixed];
     }
 
-    const auto & found2 = elementsToIndexes.find(util::lower(element));
+    prefixed = prefix.empty() ? util::lower(element) : fmt::format("{}({})", prefix, util::lower(element));
+    const auto & found2 = elementsToIndexes.find(prefixed);
     if (found2 != elementsToIndexes.end())
     {
       if (isCountingOccs)
@@ -104,9 +109,10 @@ int Dict::getIndexOrInsert(const std::string & element)
       return found2->second;
     }
 
+    prefixed = prefix.empty() ? unknownValueStr : fmt::format("{}({})", prefix, unknownValueStr);
     if (isCountingOccs)
-      nbOccs[elementsToIndexes[unknownValueStr]]++;
-    return elementsToIndexes[unknownValueStr];
+      nbOccs[elementsToIndexes[prefixed]]++;
+    return elementsToIndexes[prefixed];
   }
 
   if (isCountingOccs)
@@ -217,7 +223,7 @@ void Dict::removeRareElements()
   nbOccs = newNbOccs;
 }
 
-void Dict::loadWord2Vec(std::filesystem::path path)
+void Dict::loadWord2Vec(std::filesystem::path path, std::string prefix)
 {
   if (path.empty())
     return;
@@ -235,6 +241,16 @@
 
   try
   {
+    if (!prefix.empty())
+    {
+      std::vector<std::string> toAdd;
+      for (auto & it : elementsToIndexes)
+        if (isSpecialValue(it.first))
+          toAdd.emplace_back(fmt::format("{}({})", prefix, it.first));
+      for (auto & elem : toAdd)
+        getIndexOrInsert(elem, "");
+    }
+
     while (!std::feof(file))
     {
       if (buffer != std::fgets(buffer, 100000, file))
@@ -251,9 +267,13 @@
       if (splited.size() < 2)
         util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));
 
-      auto dictIndex = getIndexOrInsert(splited[0]);
+      if (splited[0] == "<unk>")
+        continue;
+      auto toInsert = util::splitAsUtf8(splited[0]);
+      toInsert.replace("◌", " ");
+      auto dictIndex = getIndexOrInsert(fmt::format("{}", toInsert), prefix);
 
-      if (dictIndex == getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getIndexOrInsert(Dict::nullValueStr) or dictIndex == getIndexOrInsert(Dict::emptyValueStr))
+      if (dictIndex == getIndexOrInsert(Dict::unknownValueStr, prefix) or dictIndex == getIndexOrInsert(Dict::nullValueStr, prefix) or dictIndex == getIndexOrInsert(Dict::emptyValueStr, prefix))
         util::myThrow(fmt::format("w2v line '{}' gave unexpected special dict index", buffer));
     }
   } catch (std::exception & e)
@@ -269,3 +289,18 @@
   setState(originalState);
 }
 
+bool Dict::isSpecialValue(const std::string & value)
+{
+  return value == unknownValueStr
+    || value == nullValueStr
+    || value == emptyValueStr
+    || value == separatorValueStr
+    || value == numberValueStr
+    || value == urlValueStr;
+}
+
+std::string Dict::getElement(std::size_t index)
+{
+  return indexesToElements[index];
+}
+
diff --git a/torch_modules/include/ContextualModule.hpp b/torch_modules/include/ContextualModule.hpp
index d7e290c..0395c11 100644
--- a/torch_modules/include/ContextualModule.hpp
+++ b/torch_modules/include/ContextualModule.hpp
@@ -22,7 +22,7 @@ class ContextualModuleImpl : public Submodule
   int inSize;
   int outSize;
   std::filesystem::path path;
-  std::filesystem::path w2vFile;
+  std::filesystem::path w2vFiles;
 
   public :
 
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 70250e0..77c0346 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -16,7 +16,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde
   public :
 
   void setFirstInputIndex(std::size_t firstInputIndex);
-  void loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path);
+  void loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path, std::string prefix);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
   virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index 66a6728..c83de18 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -54,9 +54,14 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
     {
       auto pathes = util::split(w2vFiles.string(), ' ');
       for (auto & p : pathes)
-        getDict().loadWord2Vec(this->path / p);
-      getDict().setState(Dict::State::Closed);
-      dictSetPretrained(true);
+      {
+        auto splited = util::split(p, ',');
+        if (splited.size() != 2)
+          util::myThrow("expected 'prefix,pretrained.w2v'");
+        getDict().loadWord2Vec(this->path / splited[1], splited[0]);
+        getDict().setState(Dict::State::Closed);
+        dictSetPretrained(true);
+      }
     }
   }
   catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
@@ -117,7 +122,7 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
       if (index == -1)
      {
         for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", col, Dict::nullValueStr)));
+          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
       }
       else
       {
@@ -126,23 +131,17 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
         {
           std::string value;
           if (config.isCommentPredicted(index))
-            value = "ID(comment)";
+            value = "comment";
           else if (config.isMultiwordPredicted(index))
-            value = "ID(multiword)";
+            value = "multiword";
           else if (config.isTokenPredicted(index))
-            value = "ID(token)";
-          dictIndex = dict.getIndexOrInsert(value);
-        }
-        else if (col == Config::EOSColName)
-        {
-          dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index)));
+            value = "token";
+          dictIndex = dict.getIndexOrInsert(value, col);
         }
         else
         {
           std::string featureValue = functions[colIndex](config.getAsFeature(col, index));
-          if (w2vFiles.empty())
-            featureValue = fmt::format("{}({})", col, featureValue);
-          dictIndex = dict.getIndexOrInsert(featureValue);
+          dictIndex = dict.getIndexOrInsert(featureValue, col);
         }
 
         for (auto & contextElement : context)
@@ -165,6 +164,9 @@ void ContextModuleImpl::registerEmbeddings()
   wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / p);
+  {
+    auto splited = util::split(p, ',');
+    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+  }
 }
 
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index ebe386a..cc06903 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -53,13 +53,20 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string &
     else
       util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
-    w2vFile = sm.str(7);
+    w2vFiles = sm.str(7);
 
-    if (!w2vFile.empty())
+    if (!w2vFiles.empty())
     {
-      getDict().loadWord2Vec(this->path / w2vFile);
-      getDict().setState(Dict::State::Closed);
-      dictSetPretrained(true);
+      auto pathes = util::split(w2vFiles.string(), ' ');
+      for (auto & p : pathes)
+      {
+        auto splited = util::split(p, ',');
+        if (splited.size() != 2)
+          util::myThrow("expected 'prefix,file.w2v'");
+        getDict().loadWord2Vec(this->path / splited[1], splited[0]);
+        getDict().setState(Dict::State::Closed);
+        dictSetPretrained(true);
+      }
     }
   }
   catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
@@ -127,17 +134,13 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
       if (index == -1)
       {
         for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", col, Dict::nullValueStr)));
+          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
       }
       else if (index == -2)
       {
+        //TODO maybe change this to a unique value like Dict::noneValueStr
         for (auto & contextElement : context)
-        {
-          auto currentState = dict.getState();
-          dict.setState(Dict::State::Open);
-          contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", col, "_NONE_")));
-          dict.setState(currentState);
-        }
+          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
       }
       else
       {
@@ -146,23 +149,17 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
         {
           std::string value;
           if (config.isCommentPredicted(index))
-            value = "ID(comment)";
+            value = "comment";
           else if (config.isMultiwordPredicted(index))
-            value = "ID(multiword)";
+            value = "multiword";
           else if (config.isTokenPredicted(index))
-            value = "ID(token)";
-          dictIndex = dict.getIndexOrInsert(value);
-        }
-        else if (col == Config::EOSColName)
-        {
-          dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index)));
+            value = "token";
+          dictIndex = dict.getIndexOrInsert(value, col);
         }
         else
         {
           std::string featureValue = config.getAsFeature(col, index);
-          if (w2vFile.empty())
-            featureValue = fmt::format("{}({})", col, featureValue);
-          dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue));
+          dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue), col);
         }
 
         for (auto & contextElement : context)
@@ -214,6 +211,12 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 void ContextualModuleImpl::registerEmbeddings()
 {
   wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
-  loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile.empty() ? "" : path / w2vFile);
+
+  auto pathes = util::split(w2vFiles.string(), ' ');
+  for (auto & p : pathes)
+  {
+    auto splited = util::split(p, ',');
+    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+  }
 }
 
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index 2cb88dc..6d97fbe 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -117,9 +117,9 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon
       for (int i = 0; i < maxElemPerDepth[depth]; i++)
         for (auto & col : columns)
           if (i < (int)newChilds.size() and config.has(col, std::stoi(newChilds[i]), 0))
-            contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i]))));
+            contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i])), col));
           else
-            contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+            contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
     }
   }
 }
diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp
index 40098bc..daf7a3c 100644
--- a/torch_modules/src/DistanceModule.cpp
+++ b/torch_modules/src/DistanceModule.cpp
@@ -86,6 +86,8 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context,
     else
       toIndexes.emplace_back(-1);
 
+  std::string prefix = "DISTANCE";
+
   for (auto & contextElement : context)
   {
     for (auto from : fromIndexes)
@@ -93,16 +95,16 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context,
       {
         if (from == -1 or to == -1)
        {
-          contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+          contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
           continue;
         }
 
         long dist = std::abs(config.getRelativeDistance(from, to));
 
         if (dist <= threshold)
-          contextElement.emplace_back(dict.getIndexOrInsert(fmt::format("distance({})", dist)));
+          contextElement.emplace_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, dist), ""));
         else
-          contextElement.emplace_back(dict.getIndexOrInsert(Dict::unknownValueStr));
+          contextElement.emplace_back(dict.getIndexOrInsert(Dict::unknownValueStr, prefix));
       }
     }
   }
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 62c1aef..556fdc4 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -84,7 +84,7 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
       if (index == -1)
       {
         for (int i = 0; i < maxNbElements; i++)
-          contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+          contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, column));
         continue;
       }
 
@@ -93,6 +93,7 @@ void FocusedColumnModuleImpl::addToCont
      {
         auto asUtf8 = util::splitAsUtf8(func(config.getAsFeature(column, index).get()));
 
+        //TODO don't use nullValueStr here
         for (int i = 0; i < maxNbElements; i++)
           if (i < (int)asUtf8.size())
             elements.emplace_back(fmt::format("{}", asUtf8[i]));
@@ -105,23 +106,23 @@
 
         for (int i = 0; i < maxNbElements; i++)
           if (i < (int)splited.size())
-            elements.emplace_back(fmt::format("FEATS({})", splited[i]));
+            elements.emplace_back(splited[i]);
           else
             elements.emplace_back(Dict::nullValueStr);
       }
       else if (column == "ID")
       {
         if (config.isTokenPredicted(index))
-          elements.emplace_back("ID(TOKEN)");
+          elements.emplace_back("TOKEN");
         else if (config.isMultiwordPredicted(index))
-          elements.emplace_back("ID(MULTIWORD)");
+          elements.emplace_back("MULTIWORD");
         else if (config.isEmptyNodePredicted(index))
-          elements.emplace_back("ID(EMPTYNODE)");
+          elements.emplace_back("EMPTYNODE");
       }
       else if (column == "EOS")
       {
         bool isEOS = func(config.getAsFeature(Config::EOSColName, index)) == Config::EOSSymbol1;
-        elements.emplace_back(fmt::format("EOS({})", isEOS));
+        elements.emplace_back(fmt::format("{}", isEOS));
       }
       else
       {
@@ -132,7 +133,7 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
         util::myThrow(fmt::format("elements.size ({}) != maxNbElements ({})", elements.size(), maxNbElements));
 
       for (auto & element : elements)
-        contextElement.emplace_back(dict.getIndexOrInsert(element));
+        contextElement.emplace_back(dict.getIndexOrInsert(element, column));
     }
   }
 }
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
index eb5c28c..7249116 100644
--- a/torch_modules/src/HistoryModule.cpp
+++ b/torch_modules/src/HistoryModule.cpp
@@ -57,12 +57,14 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
 {
   auto & dict = getDict();
 
+  std::string prefix = "HISTORY";
+
   for (auto & contextElement : context)
     for (int i = 0; i < maxNbElements; i++)
       if (config.hasHistory(i))
-        contextElement.emplace_back(dict.getIndexOrInsert(config.getHistory(i)));
+        contextElement.emplace_back(dict.getIndexOrInsert(config.getHistory(i), prefix));
       else
-        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
 }
 
 void HistoryModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index 8f43a2f..d6adb74 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -57,20 +57,22 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context,
   if (leftWindow < 0 or rightWindow < 0)
     return;
 
+  std::string prefix = "LETTER";
+
   auto & dict = getDict();
   for (auto & contextElement : context)
   {
     for (int i = 0; i < leftWindow; i++)
       if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i))
-        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i))));
+        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, config.getLetter(config.getCharacterIndex()-leftWindow+i)), ""));
       else
-        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
 
     for (int i = 0; i <= rightWindow; i++)
       if (config.hasCharacter(config.getCharacterIndex()+i))
-        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
+        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, config.getLetter(config.getCharacterIndex()+i)), ""));
       else
-        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
   }
 }
 
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index d4f6d84..43964c6 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -58,9 +58,9 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context
   for (auto & contextElement : context)
     for (int i = 0; i < maxNbTrans; i++)
       if (i < (int)splitTransitions.size())
-        contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
+        contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName(), ""));
       else
-        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, ""));
 }
 
 void SplitTransModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp
index 42edd50..18627db 100644
--- a/torch_modules/src/StateNameModule.cpp
+++ b/torch_modules/src/StateNameModule.cpp
@@ -33,7 +33,7 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context,
 {
   auto & dict = getDict();
   for (auto & contextElement : context)
-    contextElement.emplace_back(dict.getIndexOrInsert(config.getState()));
+    contextElement.emplace_back(dict.getIndexOrInsert(config.getState(), ""));
 }
 
 void StateNameModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index e52ef5e..589bc96 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -5,7 +5,7 @@ void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
   this->firstInputIndex = firstInputIndex;
 }
 
-void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path)
+void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path, std::string prefix)
 {
   if (path.empty())
     return;
@@ -44,12 +44,14 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s
       if (splited.size() < 2)
         util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));
 
-      auto dictIndex = getDict().getIndexOrInsert(splited[0]);
+      std::string word;
+
       if (splited[0] == "<unk>")
-        dictIndex = getDict().getIndexOrInsert(Dict::unknownValueStr);
+        word = Dict::unknownValueStr;
+      else
+        word = splited[0];
 
-      if (splited[0] != "<unk>" and splited[0] != Dict::unknownValueStr and (dictIndex == getDict().getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::nullValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::emptyValueStr)))
-        continue;
+      auto dictIndex = getDict().getIndexOrInsert(word, prefix);
 
       if (embeddingsSize != splited.size()-1)
         util::myThrow(fmt::format("in line \n{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'", buffer, embeddingsSize, ((int)splited.size())-1));
-- 
GitLab
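
Note on the new pretrained-embeddings syntax (the file names below are made
up; only the 'prefix,file' shape and the error message come from this patch):
where a module definition previously referenced a single w2v file, the field
now holds space-separated pairs, each pair a dict prefix and a w2v file:

    FORM,pretrained.w2v LEMMA,lemmas.w2v

For each pair the module calls getDict().loadWord2Vec(path / file, prefix),
so every vector read from the file is registered under a key of the form
FORM(word); a pair that does not split into exactly two comma-separated
fields raises "expected 'prefix,pretrained.w2v'". The reader also skips a
literal <unk> token and maps "◌" back to a space, which suggests the
pretraining pipeline escapes spaces inside tokens with that character.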
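
The core behavioral change sits in Dict::getIndexOrInsert: keys are now
stored as "prefix(element)" and the whole fallback chain operates on the
prefixed key. Below is a standalone sketch of that chain, not the project's
code: identifiers are illustrative, lowercasing is ASCII-only (the real code
uses util::lower), "<unk>" stands in for Dict::unknownValueStr, and the real
method first reroutes empty/separator/number/URL elements to special values.

    #include <algorithm>
    #include <cassert>
    #include <cctype>
    #include <string>
    #include <unordered_map>

    // Build the stored key: "prefix(element)", or the raw element if no prefix.
    static std::string makeKey(const std::string & prefix, const std::string & element)
    {
      return prefix.empty() ? element : prefix + "(" + element + ")";
    }

    static std::string lowered(std::string s)
    {
      std::transform(s.begin(), s.end(), s.begin(),
                     [](unsigned char c){ return (char)std::tolower(c); });
      return s;
    }

    static int getIndexOrInsertSketch(std::unordered_map<std::string, int> & dict,
                                      bool isOpen, const std::string & element,
                                      const std::string & prefix)
    {
      // 1) exact match on the prefixed key
      auto found = dict.find(makeKey(prefix, element));
      if (found != dict.end())
        return found->second;
      // 2) open dict: insert the prefixed key as a fresh index
      if (isOpen)
        return dict.emplace(makeKey(prefix, element), (int)dict.size()).first->second;
      // 3) closed dict: retry with the lowercased element
      found = dict.find(makeKey(prefix, lowered(element)));
      if (found != dict.end())
        return found->second;
      // 4) last resort: the prefixed unknown entry, which must pre-exist
      return dict.at(makeKey(prefix, "<unk>"));
    }

    int main()
    {
      std::unordered_map<std::string, int> d{{"FORM(<unk>)", 0}, {"FORM(chat)", 1}};
      assert(getIndexOrInsertSketch(d, false, "chat", "FORM") == 1);  // exact hit
      assert(getIndexOrInsertSketch(d, false, "Chat", "FORM") == 1);  // lowercase retry
      assert(getIndexOrInsertSketch(d, false, "chien", "FORM") == 0); // unknown fallback
      assert(getIndexOrInsertSketch(d, true, "chien", "FORM") == 2);  // open: inserted
      return 0;
    }

Step 4 is presumably why loadWord2Vec pre-registers prefixed copies of the
special values before reading a file: "PREFIX(<unk>)" must already exist in a
closed dict for the fallback to resolve.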
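
A minimal usage sketch of the patched Dict API (construction and state
handling elided; key spellings follow the "prefix(element)" scheme):

    // given some Dict 'dict' in Open state:
    int i = dict.getIndexOrInsert("chat", "FORM");   // stored as "FORM(chat)"
    int j = dict.getIndexOrInsert("chat", "LEMMA");  // separate entry "LEMMA(chat)"
    // i != j : one index per (prefix, element) pair, so the same surface
    // form no longer collides across columns
    // dict.getElement(i) == "FORM(chat)"            // new reverse lookup

Since indexesToElements is filled both by readFromFile and by insert,
getElement works for entries loaded from a saved dict as well as for entries
inserted during training.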