From 1e98bc425fa5a527cb3715e0cb3fe74eb1c35cc3 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 16 Oct 2020 12:07:27 +0200
Subject: [PATCH] Keep the Dict open while loading pretrained embeddings

Opening the Dict lets words that appear in the pretrained word2vec file but
are missing from the Dict receive fresh indices. Rows whose index falls
beyond the current embedding matrix are buffered in toAdd, and the matrix is
grown to hold them once the whole file has been read.
---
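Note (not part of the commit): the patch relies on the dict handing out fresh
indices for words it has never seen. Below is a toy stand-in sketching that
behavior; ToyDict, getIndexOrInsert and "<unk>" are names invented for this
note, the project's real Dict API may differ:

  #include <string>
  #include <unordered_map>

  // Illustrative only, not the repository's Dict implementation.
  struct ToyDict
  {
    enum class State {Open, Closed};
    State state{State::Closed};
    std::unordered_map<std::string, int> word2index{{"<unk>", 0}};

    // Open: unseen words get a fresh index. Closed: they map to <unk>.
    int getIndexOrInsert(const std::string & word)
    {
      auto it = word2index.find(word);
      if (it != word2index.end())
        return it->second;
      if (state == State::Closed)
        return word2index.at("<unk>");
      int index = (int)word2index.size();
      return word2index.emplace(word, index).first->second;
    }
  };

With a Closed dict every out-of-vocabulary pretrained word would collapse
onto the unknown index and repeatedly overwrite the same row; with an Open
dict the returned index can exceed weight.size(0), which is exactly the case
the new toAdd buffer handles.
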
 torch_modules/src/Submodule.cpp | 28 +++++++++++++++++++++++++---
 1 file changed, 25 insertions(+), 3 deletions(-)
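
The element-wise copy loops in the last hunk can also be written with
torch::cat. A minimal self-contained sketch, assuming the same toAdd layout
as in the patch (growEmbeddings and its parameters are names chosen here,
not repository code):

  #include <torch/torch.h>
  #include <vector>

  // Return `weight` with the buffered rows appended at the bottom.
  torch::Tensor growEmbeddings(const torch::Tensor & weight,
                               const std::vector<std::vector<float>> & toAdd,
                               int64_t embeddingsSize)
  {
    torch::NoGradGuard noGrad;
    // Flatten the buffered rows into one contiguous float buffer.
    std::vector<float> flat;
    flat.reserve(toAdd.size() * embeddingsSize);
    for (const auto & row : toAdd)
      flat.insert(flat.end(), row.begin(), row.end());
    // from_blob does not own its data, so clone before `flat` goes away.
    auto newRows = torch::from_blob(flat.data(),
        {(int64_t)toAdd.size(), embeddingsSize}).clone();
    return torch::cat({weight, newRows}, 0);
  }

One caveat either way: assigning a brand-new tensor to embeddings->weight
rebinds the C++ member, but a copy registered in the module's parameter list
may still point at the old storage; embeddings->weight.set_data(...) is one
way to swap the data while keeping that link intact.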

diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index 4152822..24dea4c 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -23,10 +23,12 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
   if (!std::filesystem::exists(path))
     util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));
 
+  std::vector<std::vector<float>> toAdd;
+
   torch::NoGradGuard no_grad;
 
   auto originalState = getDict().getState();
-  getDict().setState(Dict::State::Closed);
+  getDict().setState(Dict::State::Open);
 
   std::FILE * file = std::fopen(path.c_str(), "r");
   char buffer[100000];
@@ -68,8 +70,17 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
       if (embeddingsSize != splited.size()-1)
         util::myThrow(fmt::format("in line \n{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'", buffer, embeddingsSize, ((int)splited.size())-1));
 
-      for (unsigned int i = 1; i < splited.size(); i++)
-        embeddings->weight[dictIndex][i-1] = std::stof(splited[i]);
+      if (dictIndex >= embeddings->weight.size(0))
+      {
+        toAdd.emplace_back();
+        for (unsigned int i = 1; i < splited.size(); i++)
+          toAdd.back().emplace_back(std::stof(splited[i]));
+      }
+      else
+      {
+        for (unsigned int i = 1; i < splited.size(); i++)
+          embeddings->weight[dictIndex][i-1] = std::stof(splited[i]);
+      }
     }
   } catch (std::exception & e)
   {
@@ -81,6 +92,17 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std
   if (firstLine)
     util::myThrow(fmt::format("file '{}' is empty", path.string()));
 
+  if (!toAdd.empty())
+  {
+    auto newEmb = torch::nn::Embedding(embeddings->weight.size(0)+toAdd.size(), embeddingsSize);
+    for (unsigned int i = 0; i < embeddings->weight.size(0); i++)
+      newEmb->weight[i] = embeddings->weight[i];
+    for (unsigned int i = 0; i < toAdd.size(); i++)
+      for (unsigned int j = 0; j < embeddingsSize; j++)
+        newEmb->weight[embeddings->weight.size(0)+i][j] = toAdd[i][j];
+    embeddings->weight = newEmb->weight;
+  }
+
   getDict().setState(originalState);
   embeddings->weight.set_requires_grad(WordEmbeddingsImpl::getCanTrainPretrained());
 }
-- 
GitLab