From 1473579cd0b3e379ae3f2aa785b04941d1d26325 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 25 Mar 2021 12:26:07 +0100
Subject: [PATCH] Added sanity check when loading pretrained word embeddings

---
 torch_modules/src/Submodule.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index 24dea4c..818a058 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -72,6 +72,8 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
 
       if (dictIndex >= embeddings->weight.size(0))
       {
+        if ((unsigned long)dictIndex != embeddings->weight.size(0)+toAdd.size())
+          util::myThrow(fmt::format("dictIndex == {}, weight.size == {}, toAdd.size == {}", dictIndex, embeddings->weight.size(0), toAdd.size()));
         toAdd.emplace_back();
         for (unsigned int i = 1; i < splited.size(); i++)
           toAdd.back().emplace_back(std::stof(splited[i]));
@@ -166,7 +168,7 @@ std::function<std::string(const std::string &)> Submodule::getFunction(const std
 
   return [sequence](const std::string & s)
   {
-    auto result = s; 
+    auto result = s;
     for (auto & f : sequence)
       result = f(result);
     return result;
-- 
GitLab