From 73971f7c15446dca4d26c50de3c24c6669fe9532 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sat, 9 Oct 2021 17:22:12 +0200 Subject: [PATCH] Allow spaces in w2v --- torch_modules/src/Submodule.cpp | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp index 818a058..51fff9d 100644 --- a/torch_modules/src/Submodule.cpp +++ b/torch_modules/src/Submodule.cpp @@ -61,27 +61,36 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std else word = splited[0]; + for (unsigned int i = 1; i < ((int)splited.size()-embeddingsSize); i++) + word += " "+splited[i]; + auto toInsert = util::splitAsUtf8(word); toInsert.replace("◌", " "); word = fmt::format("{}", toInsert); auto dictIndex = getDict().getIndexOrInsert(word, prefix); - if (embeddingsSize != splited.size()-1) + if (embeddingsSize > splited.size()-1) util::myThrow(fmt::format("in line \n{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'", buffer, embeddingsSize, ((int)splited.size())-1)); - if (dictIndex >= embeddings->weight.size(0)) + try { - if ((unsigned long)dictIndex != embeddings->weight.size(0)+toAdd.size()) - util::myThrow(fmt::format("dictIndex == {}, weight.size == {}, toAdd.size == {}", dictIndex, embeddings->weight.size(0), toAdd.size())); - toAdd.emplace_back(); - for (unsigned int i = 1; i < splited.size(); i++) - toAdd.back().emplace_back(std::stof(splited[i])); - } - else + if (dictIndex >= embeddings->weight.size(0)) + { + if ((unsigned long)dictIndex != embeddings->weight.size(0)+toAdd.size()) + util::myThrow(fmt::format("dictIndex == {}, weight.size == {}, toAdd.size == {}", dictIndex, embeddings->weight.size(0), toAdd.size())); + toAdd.emplace_back(); + for (unsigned int i = splited.size()-embeddingsSize; i < splited.size(); i++) + toAdd.back().emplace_back(std::stof(splited[i])); + } + else + { + for (unsigned int i = splited.size()-embeddingsSize; i < splited.size(); i++) + embeddings->weight[dictIndex][i-(splited.size()-embeddingsSize)] = std::stof(splited[i]); + } + } catch (std::exception & e) { - for (unsigned int i = 1; i < splited.size(); i++) - embeddings->weight[dictIndex][i-1] = std::stof(splited[i]); + util::myThrow(fmt::format("{} in line\n{}\n", e.what(), buffer)); } } } catch (std::exception & e) -- GitLab