diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index efd580624f679b84c3e5e31498ed7f3af2269a35..dda547b91971a7d643523383b51db00343b90488 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -26,6 +26,7 @@ class Dict
   private :
 
   std::unordered_map<std::string, int> elementsToIndexes;
+  std::unordered_map<int, std::string> indexesToElements;
   std::vector<int> nbOccs;
   State state;
   bool isCountingOccs{false};
@@ -43,7 +44,8 @@ class Dict
   public :
 
   void countOcc(bool isCountingOccs);
-  int getIndexOrInsert(const std::string & element);
+  int getIndexOrInsert(const std::string & element, const std::string & prefix);
+  std::string getElement(std::size_t index);
   void setState(State state);
   State getState() const;
   void save(std::filesystem::path path, Encoding encoding) const;
@@ -52,7 +54,8 @@ class Dict
   std::size_t size() const;
   int getNbOccs(int index) const;
   void removeRareElements();
-  void loadWord2Vec(std::filesystem::path path);
+  void loadWord2Vec(std::filesystem::path path, std::string prefix);
+  bool isSpecialValue(const std::string & value);
 };
 
 #endif
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index 49e678f53329262cf8744457c2a02a225c5a249e..d4e7ba2ff7073d5cd94f3c1cb5ed3ad4d36221f7 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -42,6 +42,7 @@ void Dict::readFromFile(const char * filename)
     util::myThrow(fmt::format("file '{}' bad format", filename));
 
   elementsToIndexes.reserve(nbEntries);
+  indexesToElements.reserve(nbEntries);
 
   int entryIndex;
   int nbOccsEntry;
@@ -52,6 +53,7 @@ void Dict::readFromFile(const char * filename)
       util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));
 
     elementsToIndexes[entryString] = entryIndex;
+    indexesToElements[entryIndex] = entryString;
     while ((int)nbOccs.size() <= entryIndex)
       nbOccs.emplace_back(0);
     nbOccs[entryIndex] = nbOccsEntry;
@@ -66,37 +68,40 @@ void Dict::insert(const std::string & element)
     util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", element.size(), maxEntrySize));
 
   elementsToIndexes.emplace(element, elementsToIndexes.size());
+  indexesToElements.emplace(elementsToIndexes.size()-1, element);
   while (nbOccs.size() < elementsToIndexes.size())
     nbOccs.emplace_back(0);
 }
 
 
-int Dict::getIndexOrInsert(const std::string & element)
+int Dict::getIndexOrInsert(const std::string & element, const std::string & prefix)
 {
   if (element.empty())
-    return getIndexOrInsert(emptyValueStr);
+    return getIndexOrInsert(emptyValueStr, prefix);
 
   if (element.size() == 1 and util::isSeparator(util::utf8char(element)))
-    return getIndexOrInsert(separatorValueStr);
+    return getIndexOrInsert(separatorValueStr, prefix);
 
   if (util::isNumber(element))
-    return getIndexOrInsert(numberValueStr);
+    return getIndexOrInsert(numberValueStr, prefix);
 
   if (util::isUrl(element))
-    return getIndexOrInsert(urlValueStr);
+    return getIndexOrInsert(urlValueStr, prefix);
 
-  const auto & found = elementsToIndexes.find(element);
+  auto prefixed = prefix.empty() ? element : fmt::format("{}({})", prefix, element);
+  const auto & found = elementsToIndexes.find(prefixed);
   if (found == elementsToIndexes.end())
   {
     if (state == State::Open)
    {
-      insert(element);
+      insert(prefixed);
       if (isCountingOccs)
-        nbOccs[elementsToIndexes[element]]++;
-      return elementsToIndexes[element];
+        nbOccs[elementsToIndexes[prefixed]]++;
+      return elementsToIndexes[prefixed];
     }
 
-    const auto & found2 = elementsToIndexes.find(util::lower(element));
+    prefixed = prefix.empty() ? util::lower(element) : fmt::format("{}({})", prefix, util::lower(element));
+    const auto & found2 = elementsToIndexes.find(prefixed);
     if (found2 != elementsToIndexes.end())
     {
       if (isCountingOccs)
@@ -104,9 +109,10 @@ int Dict::getIndexOrInsert(const std::string & element)
       return found2->second;
     }
 
+    prefixed = prefix.empty() ? unknownValueStr : fmt::format("{}({})", prefix, unknownValueStr);
     if (isCountingOccs)
-      nbOccs[elementsToIndexes[unknownValueStr]]++;
-    return elementsToIndexes[unknownValueStr];
+      nbOccs[elementsToIndexes[prefixed]]++;
+    return elementsToIndexes[prefixed];
   }
 
   if (isCountingOccs)
@@ -217,7 +223,7 @@ void Dict::removeRareElements()
   nbOccs = newNbOccs;
 }
 
-void Dict::loadWord2Vec(std::filesystem::path path)
+void Dict::loadWord2Vec(std::filesystem::path path, std::string prefix)
 {
   if (path.empty())
     return;
@@ -235,6 +241,16 @@ void Dict::loadWord2Vec(std::filesystem::path path)
 
   try
   {
+    if (!prefix.empty())
+    {
+      std::vector<std::string> toAdd;
+      for (auto & it : elementsToIndexes)
+        if (isSpecialValue(it.first))
+          toAdd.emplace_back(fmt::format("{}({})", prefix, it.first));
+      for (auto & elem : toAdd)
+        getIndexOrInsert(elem, "");
+    }
+
     while (!std::feof(file))
     {
       if (buffer != std::fgets(buffer, 100000, file))
@@ -251,9 +267,13 @@ void Dict::loadWord2Vec(std::filesystem::path path)
       if (splited.size() < 2)
         util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));
 
-      auto dictIndex = getIndexOrInsert(splited[0]);
+      if (splited[0] == "<unk>")
+        continue;
+      auto toInsert = util::splitAsUtf8(splited[0]);
+      toInsert.replace("◌", " ");
+      auto dictIndex = getIndexOrInsert(fmt::format("{}", toInsert), prefix);
 
-      if (dictIndex == getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getIndexOrInsert(Dict::nullValueStr) or dictIndex == getIndexOrInsert(Dict::emptyValueStr))
+      if (dictIndex == getIndexOrInsert(Dict::unknownValueStr, prefix) or dictIndex == getIndexOrInsert(Dict::nullValueStr, prefix) or dictIndex == getIndexOrInsert(Dict::emptyValueStr, prefix))
         util::myThrow(fmt::format("w2v line '{}' gave unexpected special dict index", buffer));
     }
   } catch (std::exception & e)
@@ -269,3 +289,18 @@ void Dict::loadWord2Vec(std::filesystem::path path)
 
   setState(originalState);
 }
+
+bool Dict::isSpecialValue(const std::string & value)
+{
+  return value == unknownValueStr
+      || value == nullValueStr
+      || value == emptyValueStr
+      || value == separatorValueStr
+      || value == numberValueStr
+      || value == urlValueStr;
+}
+
+std::string Dict::getElement(std::size_t index)
+{
+  return indexesToElements[index];
+}
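
Note on the Dict change: entries are now namespaced as "prefix(element)", and
indexesToElements makes the mapping invertible through getElement(). A minimal
usage sketch, not part of the patch (the Dict constructor call is hypothetical):

    // The same surface form gets distinct entries under distinct prefixes.
    Dict dict("example", Dict::State::Open);          // hypothetical ctor
    int a = dict.getIndexOrInsert("chat", "FORM");    // stored as "FORM(chat)"
    int b = dict.getIndexOrInsert("chat", "LEMMA");   // stored as "LEMMA(chat)"
    assert(a != b);
    assert(dict.getElement(a) == "FORM(chat)");       // reverse lookup
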
diff --git a/torch_modules/include/ContextualModule.hpp b/torch_modules/include/ContextualModule.hpp
index d7e290cc066fa2a1a842b02f702c5c717390e062..0395c11f78a987f0e88724f2aa82d978587bf7be 100644
--- a/torch_modules/include/ContextualModule.hpp
+++ b/torch_modules/include/ContextualModule.hpp
@@ -22,7 +22,7 @@ class ContextualModuleImpl : public Submodule
   int inSize;
   int outSize;
   std::filesystem::path path;
-  std::filesystem::path w2vFile;
+  std::filesystem::path w2vFiles;
 
   public :
 
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 70250e0a43f50468c05786d90884bd0ae7f7e1c1..77c03468aafe5be7998fa27be387a5e3c6bd4600 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -16,7 +16,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde
   public :
 
   void setFirstInputIndex(std::size_t firstInputIndex);
-  void loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path);
+  void loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path, std::string prefix);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
   virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index 66a672869171c4a1090118fb6bebc9375f91d653..c83de1831d570f50533ff25adced364115bb0270 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -54,9 +54,14 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
       {
         auto pathes = util::split(w2vFiles.string(), ' ');
         for (auto & p : pathes)
-          getDict().loadWord2Vec(this->path / p);
-        getDict().setState(Dict::State::Closed);
-        dictSetPretrained(true);
+        {
+          auto splited = util::split(p, ',');
+          if (splited.size() != 2)
+            util::myThrow("expected 'prefix,pretrained.w2v'");
+          getDict().loadWord2Vec(this->path / splited[1], splited[0]);
+          getDict().setState(Dict::State::Closed);
+          dictSetPretrained(true);
+        }
       }
 
   } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
@@ -117,7 +122,7 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
     if (index == -1)
     {
       for (auto & contextElement : context)
-        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", col, Dict::nullValueStr)));
+        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
     }
     else
     {
@@ -126,23 +131,17 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
       {
         std::string value;
         if (config.isCommentPredicted(index))
-          value = "ID(comment)";
+          value = "comment";
         else if (config.isMultiwordPredicted(index))
-          value = "ID(multiword)";
+          value = "multiword";
         else if (config.isTokenPredicted(index))
-          value = "ID(token)";
-        dictIndex = dict.getIndexOrInsert(value);
-      }
-      else if (col == Config::EOSColName)
-      {
-        dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index)));
+          value = "token";
+        dictIndex = dict.getIndexOrInsert(value, col);
       }
       else
       {
         std::string featureValue = functions[colIndex](config.getAsFeature(col, index));
-        if (w2vFiles.empty())
-          featureValue = fmt::format("{}({})", col, featureValue);
-        dictIndex = dict.getIndexOrInsert(featureValue);
+        dictIndex = dict.getIndexOrInsert(featureValue, col);
       }
 
       for (auto & contextElement : context)
@@ -165,6 +164,9 @@ void ContextModuleImpl::registerEmbeddings()
   wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / p);
+  {
+    auto splited = util::split(p, ',');
+    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+  }
 }
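
The w2v field of a module definition now takes space-separated "prefix,file"
pairs instead of bare paths. A hedged illustration of the parsing above (the
file names are invented):

    // "FORM,form.w2v LEMMA,lemma.w2v" yields two (prefix, file) pairs:
    for (auto & p : util::split("FORM,form.w2v LEMMA,lemma.w2v", ' '))
    {
      auto splited = util::split(p, ',');
      // splited[0] -> dict prefix ("FORM", then "LEMMA")
      // splited[1] -> w2v file, loaded relative to this->path
    }
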
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index ebe386a723f69159b34206ad9787727987448aa0..cc0690322b0c0d1c920846d01018156e8ecdbeb7 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -53,13 +53,20 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string &
     else
       util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
-    w2vFile = sm.str(7);
+    w2vFiles = sm.str(7);
 
-    if (!w2vFile.empty())
+    if (!w2vFiles.empty())
     {
-      getDict().loadWord2Vec(this->path / w2vFile);
-      getDict().setState(Dict::State::Closed);
-      dictSetPretrained(true);
+      auto pathes = util::split(w2vFiles.string(), ' ');
+      for (auto & p : pathes)
+      {
+        auto splited = util::split(p, ',');
+        if (splited.size() != 2)
+          util::myThrow("expected 'prefix,file.w2v'");
+        getDict().loadWord2Vec(this->path / splited[1], splited[0]);
+        getDict().setState(Dict::State::Closed);
+        dictSetPretrained(true);
+      }
     }
 
   } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
@@ -127,17 +134,13 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
     if (index == -1)
     {
       for (auto & contextElement : context)
-        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", col, Dict::nullValueStr)));
+        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
     }
     else if (index == -2)
     {
+      //TODO maybe change this to a unique value like Dict::noneValueStr
       for (auto & contextElement : context)
-      {
-        auto currentState = dict.getState();
-        dict.setState(Dict::State::Open);
-        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", col, "_NONE_")));
-        dict.setState(currentState);
-      }
+        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
     }
     else
     {
@@ -146,23 +149,17 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
      {
         std::string value;
         if (config.isCommentPredicted(index))
-          value = "ID(comment)";
+          value = "comment";
         else if (config.isMultiwordPredicted(index))
-          value = "ID(multiword)";
+          value = "multiword";
         else if (config.isTokenPredicted(index))
-          value = "ID(token)";
-        dictIndex = dict.getIndexOrInsert(value);
-      }
-      else if (col == Config::EOSColName)
-      {
-        dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index)));
+          value = "token";
+        dictIndex = dict.getIndexOrInsert(value, col);
       }
       else
       {
         std::string featureValue = config.getAsFeature(col, index);
-        if (w2vFile.empty())
-          featureValue = fmt::format("{}({})", col, featureValue);
-        dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue));
+        dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue), col);
       }
 
       for (auto & contextElement : context)
@@ -214,6 +211,12 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 
 void ContextualModuleImpl::registerEmbeddings()
 {
   wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
-  loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile.empty() ? "" : path / w2vFile);
+
+  auto pathes = util::split(w2vFiles.string(), ' ');
+  for (auto & p : pathes)
+  {
+    auto splited = util::split(p, ',');
+    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+  }
 }
"" : path / w2vFile); + + auto pathes = util::split(w2vFiles.string(), ' '); + for (auto & p : pathes) + { + auto splited = util::split(p, ','); + loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]); + } } diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp index 2cb88dce6f679bbc79a6cbf10b0558d27415f698..6d97fbe813a18184e55b254d17bdea9ce1b35f16 100644 --- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp +++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp @@ -117,9 +117,9 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon for (int i = 0; i < maxElemPerDepth[depth]; i++) for (auto & col : columns) if (i < (int)newChilds.size() and config.has(col, std::stoi(newChilds[i]), 0)) - contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i])))); + contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i])), col)); else - contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); + contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, col)); } } } diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp index 40098bc37e96c2008cf4f331b761debfaaf022f9..daf7a3c1488bf6ee7334a5461d5c65e54ab71ced 100644 --- a/torch_modules/src/DistanceModule.cpp +++ b/torch_modules/src/DistanceModule.cpp @@ -86,6 +86,8 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context, else toIndexes.emplace_back(-1); + std::string prefix = "DISTANCE"; + for (auto & contextElement : context) { for (auto from : fromIndexes) @@ -93,16 +95,16 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context, { if (from == -1 or to == -1) { - contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); + contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix)); continue; } long dist = std::abs(config.getRelativeDistance(from, to)); if (dist <= threshold) - contextElement.emplace_back(dict.getIndexOrInsert(fmt::format("distance({})", dist))); + contextElement.emplace_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, dist), "")); else - contextElement.emplace_back(dict.getIndexOrInsert(Dict::unknownValueStr)); + contextElement.emplace_back(dict.getIndexOrInsert(Dict::unknownValueStr, prefix)); } } } diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 62c1aef1e364b2616ef48d6559468d944546115e..556fdc40f8387bf9b5d4dfdcdb5473cd5a553bbd 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -84,7 +84,7 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont if (index == -1) { for (int i = 0; i < maxNbElements; i++) - contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); + contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, column)); continue; } @@ -93,6 +93,7 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont { auto asUtf8 = util::splitAsUtf8(func(config.getAsFeature(column, index).get())); + //TODO don't use nullValueStr here for (int i = 0; i < maxNbElements; i++) if (i < (int)asUtf8.size()) elements.emplace_back(fmt::format("{}", asUtf8[i])); @@ -105,23 +106,23 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont for (int i = 0; i < maxNbElements; i++) if (i < 
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 62c1aef1e364b2616ef48d6559468d944546115e..556fdc40f8387bf9b5d4dfdcdb5473cd5a553bbd 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -84,7 +84,7 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
       if (index == -1)
       {
         for (int i = 0; i < maxNbElements; i++)
-          contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+          contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, column));
         continue;
       }
 
@@ -93,6 +93,7 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
       {
         auto asUtf8 = util::splitAsUtf8(func(config.getAsFeature(column, index).get()));
 
+        //TODO don't use nullValueStr here
         for (int i = 0; i < maxNbElements; i++)
           if (i < (int)asUtf8.size())
             elements.emplace_back(fmt::format("{}", asUtf8[i]));
@@ -105,23 +106,23 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
 
         for (int i = 0; i < maxNbElements; i++)
           if (i < (int)splited.size())
-            elements.emplace_back(fmt::format("FEATS({})", splited[i]));
+            elements.emplace_back(splited[i]);
           else
             elements.emplace_back(Dict::nullValueStr);
       }
       else if (column == "ID")
       {
         if (config.isTokenPredicted(index))
-          elements.emplace_back("ID(TOKEN)");
+          elements.emplace_back("TOKEN");
         else if (config.isMultiwordPredicted(index))
-          elements.emplace_back("ID(MULTIWORD)");
+          elements.emplace_back("MULTIWORD");
         else if (config.isEmptyNodePredicted(index))
-          elements.emplace_back("ID(EMPTYNODE)");
+          elements.emplace_back("EMPTYNODE");
       }
       else if (column == "EOS")
       {
         bool isEOS = func(config.getAsFeature(Config::EOSColName, index)) == Config::EOSSymbol1;
-        elements.emplace_back(fmt::format("EOS({})", isEOS));
+        elements.emplace_back(fmt::format("{}", isEOS));
       }
       else
       {
@@ -132,7 +133,7 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
         util::myThrow(fmt::format("elements.size ({}) != maxNbElements ({})", elements.size(), maxNbElements));
 
       for (auto & element : elements)
-        contextElement.emplace_back(dict.getIndexOrInsert(element));
+        contextElement.emplace_back(dict.getIndexOrInsert(element, column));
     }
   }
 }
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
index eb5c28c339197914f7b81b8f6f6c5ce5bec8f4f2..724911699495a3a4febd58cfa36bd0e79a3d483b 100644
--- a/torch_modules/src/HistoryModule.cpp
+++ b/torch_modules/src/HistoryModule.cpp
@@ -57,12 +57,14 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
 {
   auto & dict = getDict();
 
+  std::string prefix = "HISTORY";
+
   for (auto & contextElement : context)
     for (int i = 0; i < maxNbElements; i++)
      if (config.hasHistory(i))
-        contextElement.emplace_back(dict.getIndexOrInsert(config.getHistory(i)));
+        contextElement.emplace_back(dict.getIndexOrInsert(config.getHistory(i), prefix));
       else
-        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
 }
 
 void HistoryModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index 8f43a2fd310853793e93caaebc348ab1ff18b0be..d6adb74c7277b9e0fec2d8f6c9b660c174bfed08 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -57,20 +57,22 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context,
   if (leftWindow < 0 or rightWindow < 0)
     return;
 
+  std::string prefix = "LETTER";
+
   auto & dict = getDict();
   for (auto & contextElement : context)
   {
     for (int i = 0; i < leftWindow; i++)
       if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i))
-        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i))));
+        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, config.getLetter(config.getCharacterIndex()-leftWindow+i)), ""));
       else
-        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
 
     for (int i = 0; i <= rightWindow; i++)
       if (config.hasCharacter(config.getCharacterIndex()+i))
-        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
+        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, config.getLetter(config.getCharacterIndex()+i)), ""));
       else
-        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
   }
 }
 
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index d4f6d84067329448d1c5b9b5f6ccba01157b279f..43964c696ff3965c450edb24a9a6ec21e53a531c 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -58,9 +58,9 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context
   for (auto & contextElement : context)
     for (int i = 0; i < maxNbTrans; i++)
       if (i < (int)splitTransitions.size())
-        contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
+        contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName(), ""));
       else
-        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, ""));
 }
 
 void SplitTransModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp
index 42edd50ee4621080b512782378bf87e2c1703235..18627db6b58668b4d69f0412100e3619d6448f02 100644
--- a/torch_modules/src/StateNameModule.cpp
+++ b/torch_modules/src/StateNameModule.cpp
@@ -33,7 +33,7 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context,
 {
   auto & dict = getDict();
   for (auto & contextElement : context)
-    contextElement.emplace_back(dict.getIndexOrInsert(config.getState()));
+    contextElement.emplace_back(dict.getIndexOrInsert(config.getState(), ""));
 }
 
 void StateNameModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index e52ef5e00a8760d550bdb1b4bdee0b38236acaaf..589bc96bfcd50fbc6ad4de6c42e3dfc47b89fbb8 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -5,7 +5,7 @@ void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
   this->firstInputIndex = firstInputIndex;
 }
 
-void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path)
+void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path, std::string prefix)
 {
   if (path.empty())
     return;
@@ -44,12 +44,14 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s
       if (splited.size() < 2)
         util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));
 
-      auto dictIndex = getDict().getIndexOrInsert(splited[0]);
+      std::string word;
+
       if (splited[0] == "<unk>")
-        dictIndex = getDict().getIndexOrInsert(Dict::unknownValueStr);
+        word = Dict::unknownValueStr;
+      else
+        word = splited[0];
 
-      if (splited[0] != "<unk>" and splited[0] != Dict::unknownValueStr and (dictIndex == getDict().getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::nullValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::emptyValueStr)))
-        continue;
+      auto dictIndex = getDict().getIndexOrInsert(word, prefix);
 
       if (embeddingsSize != splited.size()-1)
         util::myThrow(fmt::format("in line \n{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'", buffer, embeddingsSize, ((int)splited.size())-1));
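
Taken together, a hedged sketch of the intended load order for one submodule
with a single pretrained file (the file name and setup are invented; the
identifiers come from the patch):

    // 1) Construction: grow the dict from the pretrained vocabulary,
    //    under the "FORM" prefix, then freeze it.
    getDict().loadWord2Vec(path / "form.w2v", "FORM");
    getDict().setState(Dict::State::Closed);

    // 2) registerEmbeddings(): size the matrix to the final dict, then
    //    copy each pretrained vector into its row; "<unk>" rows land on
    //    the prefixed unknownValueStr entry instead of being skipped.
    wordEmbeddings = register_module("embeddings", torch::nn::Embedding(
        torch::nn::EmbeddingOptions(getDict().size(), inSize)));
    loadPretrainedW2vEmbeddings(wordEmbeddings, path / "form.w2v", "FORM");
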