From 397e390f75da72065b78f1f72592f20be0cbd9ec Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 31 Jul 2020 17:53:24 +0200
Subject: [PATCH] FocusedModule can now have pretrained word embeddings

---
 torch_modules/include/FocusedColumnModule.hpp |  4 ++-
 torch_modules/src/FocusedColumnModule.cpp     | 26 +++++++++++++++++--
 torch_modules/src/ModularNetwork.cpp          |  2 +-
 torch_modules/src/Submodule.cpp               |  4 +++
 4 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp
index 024c6c1..85a55de 100644
--- a/torch_modules/include/FocusedColumnModule.hpp
+++ b/torch_modules/include/FocusedColumnModule.hpp
@@ -19,10 +19,12 @@ class FocusedColumnModuleImpl : public Submodule
   std::function<std::string(const std::string&)> func{[](const std::string &s){return s;}};
   int maxNbElements;
   int inSize;
+  std::filesystem::path path;
+  std::filesystem::path w2vFiles;
 
   public :
 
-  FocusedColumnModuleImpl(std::string name, const std::string & definition);
+  FocusedColumnModuleImpl(std::string name, const std::string & definition, std::filesystem::path path);
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 556fdc4..1ed8da9 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -1,9 +1,9 @@
 #include "FocusedColumnModule.hpp"
 
-FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::string & definition)
+FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::string & definition, std::filesystem::path path) : path(path)
 {
   setName(name);
-  std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)w2v\\{(.*)\\}(?:(?:\\s|\\t)*)");
   if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
         {
           try
@@ -39,6 +39,22 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
             else
               util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
+            w2vFiles = sm.str(9);
+
+            if (!w2vFiles.empty())
+            {
+              auto pathes = util::split(w2vFiles.string(), ' ');
+              for (auto & p : pathes)
+              {
+                auto splited = util::split(p, ',');
+                if (splited.size() != 2)
+                  util::myThrow("expected 'prefix,pretrained.w2v'");
+                getDict().loadWord2Vec(this->path / splited[1], splited[0]);
+                getDict().setState(Dict::State::Closed);
+                dictSetPretrained(true);
+              }
+            }
+
           } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
         }))
     util::myThrow(fmt::format("invalid definition '{}'", definition));
@@ -141,5 +157,11 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
 void FocusedColumnModuleImpl::registerEmbeddings()
 {
   wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  auto pathes = util::split(w2vFiles.string(), ' ');
+  for (auto & p : pathes)
+  {
+    auto splited = util::split(p, ',');
+    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+  }
 }
 
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index 7dcf1c5..e2e225c 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -40,7 +40,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
     else if (splited.first == "UppercaseRate")
       modules.emplace_back(register_module(name, UppercaseRateModule(nameH, splited.second)));
     else if (splited.first == "Focused")
-      modules.emplace_back(register_module(name, FocusedColumnModule(nameH, splited.second)));
+      modules.emplace_back(register_module(name, FocusedColumnModule(nameH, splited.second, path)));
     else if (splited.first == "RawInput")
       modules.emplace_back(register_module(name, RawInputModule(nameH, splited.second)));
     else if (splited.first == "SplitTrans")
diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index 589bc96..7681c9e 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -51,6 +51,10 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s
       else
         word = splited[0];
 
+      auto toInsert = util::splitAsUtf8(word);
+      toInsert.replace("◌", " ");
+      word = fmt::format("{}", toInsert);
+
       auto dictIndex = getDict().getIndexOrInsert(word, prefix);
 
       if (embeddingsSize != splited.size()-1)
-- 
GitLab