From e925b55938d60348fc4343b36128a31b8414b27f Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 28 Jul 2020 17:53:48 +0200 Subject: [PATCH] <unk> in pretrained embeddings is used for Dict::unknownValueStr --- torch_modules/src/Submodule.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp index 66f6455..e52ef5e 100644 --- a/torch_modules/src/Submodule.cpp +++ b/torch_modules/src/Submodule.cpp @@ -45,8 +45,10 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer)); auto dictIndex = getDict().getIndexOrInsert(splited[0]); + if (splited[0] == "<unk>") + dictIndex = getDict().getIndexOrInsert(Dict::unknownValueStr); - if (dictIndex == getDict().getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::nullValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::emptyValueStr)) + if (splited[0] != "<unk>" and splited[0] != Dict::unknownValueStr and (dictIndex == getDict().getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::nullValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::emptyValueStr))) continue; if (embeddingsSize != splited.size()-1) -- GitLab