diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp index 66f645547d876eab18c6e18971404013f67fac31..e52ef5e00a8760d550bdb1b4bdee0b38236acaaf 100644 --- a/torch_modules/src/Submodule.cpp +++ b/torch_modules/src/Submodule.cpp @@ -45,8 +45,10 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer)); auto dictIndex = getDict().getIndexOrInsert(splited[0]); + if (splited[0] == "<unk>") + dictIndex = getDict().getIndexOrInsert(Dict::unknownValueStr); - if (dictIndex == getDict().getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::nullValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::emptyValueStr)) + if (splited[0] != "<unk>" and splited[0] != Dict::unknownValueStr and (dictIndex == getDict().getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::nullValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::emptyValueStr))) continue; if (embeddingsSize != splited.size()-1)