From 1473579cd0b3e379ae3f2aa785b04941d1d26325 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Thu, 25 Mar 2021 12:26:07 +0100 Subject: [PATCH] Added sanity check when loading pretrained word embeddings --- torch_modules/src/Submodule.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp index 24dea4c..818a058 100644 --- a/torch_modules/src/Submodule.cpp +++ b/torch_modules/src/Submodule.cpp @@ -72,6 +72,8 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std if (dictIndex >= embeddings->weight.size(0)) { + if ((unsigned long)dictIndex != embeddings->weight.size(0)+toAdd.size()) + util::myThrow(fmt::format("dictIndex == {}, weight.size == {}, toAdd.size == {}", dictIndex, embeddings->weight.size(0), toAdd.size())); toAdd.emplace_back(); for (unsigned int i = 1; i < splited.size(); i++) toAdd.back().emplace_back(std::stof(splited[i])); @@ -166,7 +168,7 @@ std::function<std::string(const std::string &)> Submodule::getFunction(const std return [sequence](const std::string & s) { - auto result = s; + auto result = s; for (auto & f : sequence) result = f(result); return result; -- GitLab