From 675d8f42c8ae6e201aa031517c40e2f17684ef20 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 30 Jul 2020 21:40:34 +0200
Subject: [PATCH] allow multiple pretrained embeddings files for ContextModule

---
 torch_modules/include/ContextModule.hpp |  2 +-
 torch_modules/src/ContextModule.cpp     | 14 +++++++++-----
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp
index 508ba71..fc24680 100644
--- a/torch_modules/include/ContextModule.hpp
+++ b/torch_modules/include/ContextModule.hpp
@@ -21,7 +21,7 @@ class ContextModuleImpl : public Submodule
   std::vector<std::tuple<Config::Object, int, std::optional<int>>> targets;
   int inSize;
   std::filesystem::path path;
-  std::filesystem::path w2vFile;
+  std::filesystem::path w2vFiles;
 
   public :
 
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index 9016938..66a6728 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -48,11 +48,13 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
-            w2vFile = sm.str(7);
+            w2vFiles = sm.str(7);
 
-            if (!w2vFile.empty())
+            if (!w2vFiles.empty())
             {
-              getDict().loadWord2Vec(this->path / w2vFile);
+              auto paths = util::split(w2vFiles.string(), ' ');
+              for (auto & p : paths)
+                getDict().loadWord2Vec(this->path / p);
               getDict().setState(Dict::State::Closed);
               dictSetPretrained(true);
             }
@@ -138,7 +140,7 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
         else
         {
           std::string featureValue = functions[colIndex](config.getAsFeature(col, index));
-          if (w2vFile.empty())
+          if (w2vFiles.empty())
             featureValue = fmt::format("{}({})", col, featureValue);
           dictIndex = dict.getIndexOrInsert(featureValue);
         }
@@ -161,6 +163,8 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 void ContextModuleImpl::registerEmbeddings()
 {
   wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
-  loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile.empty() ? "" : path / w2vFile);
+  auto paths = util::split(w2vFiles.string(), ' ');
+  for (auto & p : paths)
+    loadPretrainedW2vEmbeddings(wordEmbeddings, path / p);
 }
 
-- 
GitLab