From e925b55938d60348fc4343b36128a31b8414b27f Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 28 Jul 2020 17:53:48 +0200
Subject: [PATCH] <unk> in pretrained embeddings is used for
 Dict::unknownValueStr

---
 torch_modules/src/Submodule.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index 66f6455..e52ef5e 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -45,8 +45,10 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s
         util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));
 
       auto dictIndex = getDict().getIndexOrInsert(splited[0]);
+      if (splited[0] == "<unk>")
+        dictIndex = getDict().getIndexOrInsert(Dict::unknownValueStr);
 
-      if (dictIndex == getDict().getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::nullValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::emptyValueStr))
+      if (splited[0] != "<unk>" and splited[0] != Dict::unknownValueStr and (dictIndex == getDict().getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::nullValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::emptyValueStr)))
         continue;
 
       if (embeddingsSize != splited.size()-1)
-- 
GitLab