From 5800a6f34225ecc0a4f98853fe4033e8d81337e6 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 10 Oct 2021 10:15:39 +0200
Subject: [PATCH] Special embeddings can be trained even with lockPretrained
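
When pretrained word embeddings are locked (lockPretrained), the vectors
of the special dict values (unknown, null, oob, nochild, empty,
separator, number, url) used to be frozen together with them, even
though these values are not covered by the pretrained w2v files.
WordEmbeddings is now split into two lookup tables: normalEmbeddings,
which receives the pretrained vectors and may stay frozen, and
specialEmbeddings, which remains trainable in every case. Dict records
every prefix it has seen and exposes getSpecialIndexes(), so that the
modules that load w2v files can tell WordEmbeddings which indexes must
be routed through the trainable table; modules without pretrained
vectors pass an empty set, and forward() then falls back to the normal
table alone.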

---
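Notes: below is a minimal standalone sketch of the index-routing scheme
implemented in WordEmbeddingsImpl::forward. It reformulates the
index_put_ blending with torch::where for readability; the sizes and
input values are illustrative, not taken from the parser.

    #include <torch/torch.h>
    #include <iostream>

    int main()
    {
      const int64_t vocab = 10, dim = 4, nbSpecial = 3;
      torch::nn::Embedding normal(torch::nn::EmbeddingOptions(vocab, dim));
      torch::nn::Embedding special(torch::nn::EmbeddingOptions(nbSpecial, dim));

      auto input = torch::tensor({0, 2, 5, 9}, torch::kLong);

      // True for indexes served by the normal (possibly frozen) table
      auto mask = input >= nbSpecial;
      // Remap normal indexes to 0 so the special lookup stays in range
      auto specialInput = input * mask.logical_not().to(torch::kLong);

      auto normalRes = normal(input);          // shape (4, dim)
      auto specialRes = special(specialInput); // shape (4, dim)

      // Row-wise selection: normalRes where mask is true, else specialRes
      auto out = torch::where(mask.unsqueeze(1), normalRes, specialRes);
      std::cout << out << '\n';
      return 0;
    }
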
 common/include/Dict.hpp                       |  3 ++
 common/src/Dict.cpp                           | 24 ++++++++++++
 torch_modules/include/WordEmbeddings.hpp      |  7 ++--
 torch_modules/src/ContextModule.cpp           |  4 +--
 torch_modules/src/ContextualModule.cpp        |  4 +--
 .../src/DepthLayerTreeEmbeddingModule.cpp     |  2 +-
 torch_modules/src/DistanceModule.cpp          |  2 +-
 torch_modules/src/FocusedColumnModule.cpp     |  4 +--
 torch_modules/src/HistoryMineModule.cpp       |  2 +-
 torch_modules/src/HistoryModule.cpp           |  2 +-
 torch_modules/src/RawInputModule.cpp          |  2 +-
 torch_modules/src/SplitTransModule.cpp        |  2 +-
 torch_modules/src/StateNameModule.cpp         |  2 +-
 torch_modules/src/WordEmbeddings.cpp          | 40 +++++++++++++++++----
 14 files changed, 78 insertions(+), 22 deletions(-)

diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index 7ff6e01..93774a1 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -6,6 +6,7 @@
 #include <vector>
 #include <filesystem>
 #include <mutex>
+#include <set>
 
 class Dict
 {
@@ -34,6 +35,7 @@ class Dict
   std::mutex elementsMutex;
   State state;
   bool isCountingOccs{false};
+  std::set<std::string> prefixes{""}; // every prefix ever passed to getIndexOrInsert
 
   public :
 
@@ -50,6 +52,7 @@ class Dict
   public :
 
   void countOcc(bool isCountingOccs);
+  std::set<std::size_t> getSpecialIndexes();
   int getIndexOrInsert(const std::string & element, const std::string & prefix);
   std::string getElement(std::size_t index);
   void setState(State state);
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index c6731cc..0eead58 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -94,6 +94,7 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & pref
   if (state == State::Open)
     elementsMutex.lock();
 
+  prefixes.insert(prefix); // remember the prefix for getSpecialIndexes()
   int index = _getIndexOrInsert(element, prefix);
 
   if (state == State::Open)
@@ -350,6 +351,29 @@ bool Dict::isSpecialValue(const std::string & value)
   || value == urlValueStr;
 }
 
+std::set<std::size_t> Dict::getSpecialIndexes()
+{
+  // Close the dict during the lookups below, so that they cannot insert new elements
+  auto oldState = getState();
+  setState(State::Closed);
+  std::set<std::string> specials = {
+    unknownValueStr,
+    nullValueStr,
+    oobValueStr,
+    noChildValueStr,
+    emptyValueStr,
+    separatorValueStr,
+    numberValueStr,
+    urlValueStr,
+  };
+  std::set<std::size_t> res;
+  for (auto & prefix : prefixes)
+    for (auto & special : specials)
+      res.insert(getIndexOrInsert(special, prefix));
+  setState(oldState);
+  return res;
+}
+
 std::string Dict::getElement(std::size_t index)
 {
   return indexesToElements[index];
diff --git a/torch_modules/include/WordEmbeddings.hpp b/torch_modules/include/WordEmbeddings.hpp
index c81d728..c9b225d 100644
--- a/torch_modules/include/WordEmbeddings.hpp
+++ b/torch_modules/include/WordEmbeddings.hpp
@@ -13,7 +13,8 @@ class WordEmbeddingsImpl : public torch::nn::Module
 
   private :
   
-  torch::nn::Embedding embeddings{nullptr};
+  torch::nn::Embedding normalEmbeddings{nullptr};
+  torch::nn::Embedding specialEmbeddings{nullptr};
 
   public :
 
@@ -22,8 +23,8 @@ class WordEmbeddingsImpl : public torch::nn::Module
   static void setCanTrainPretrained(bool value);
   static bool getCanTrainPretrained();
 
-  WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
-  torch::nn::Embedding get();
+  WordEmbeddingsImpl(std::size_t vocab, std::size_t dim, std::set<std::size_t> specialIndexes);
+  torch::nn::Embedding getNormalEmbeddings();
   torch::Tensor forward(torch::Tensor input);
 };
 TORCH_MODULE(WordEmbeddings);
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index ffea1d0..b99f6ea 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -187,12 +187,12 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 void ContextModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]);
   }
 }
 
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index d435648..6992524 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -234,13 +234,13 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 void ContextualModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
 
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]);
   }
 }
 
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index 6945a34..a60433e 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -131,6 +131,6 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(torch::Tensor & context, co
 void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
 
diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp
index fddf2e0..a51eea0 100644
--- a/torch_modules/src/DistanceModule.cpp
+++ b/torch_modules/src/DistanceModule.cpp
@@ -113,6 +113,6 @@ void DistanceModuleImpl::addToContext(torch::Tensor & context, const Config & co
 void DistanceModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
 
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 77f8c26..107e956 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -164,12 +164,12 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config
 void FocusedColumnModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->getNormalEmbeddings(), path / splited[1], splited[0]);
   }
 }
 
diff --git a/torch_modules/src/HistoryMineModule.cpp b/torch_modules/src/HistoryMineModule.cpp
index cf8338d..25bfcc1 100644
--- a/torch_modules/src/HistoryMineModule.cpp
+++ b/torch_modules/src/HistoryMineModule.cpp
@@ -69,6 +69,6 @@ void HistoryMineModuleImpl::addToContext(torch::Tensor & context, const Config &
 void HistoryMineModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
 
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
index 4a9033f..dddfdf7 100644
--- a/torch_modules/src/HistoryModule.cpp
+++ b/torch_modules/src/HistoryModule.cpp
@@ -69,6 +69,6 @@ void HistoryModuleImpl::addToContext(torch::Tensor & context, const Config & con
 void HistoryModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
 
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index 2d6bd62..d237485 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -87,6 +87,6 @@ void RawInputModuleImpl::addToContext(torch::Tensor & context, const Config & co
 void RawInputModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
 
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index dcb78e1..5f361a0 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -65,6 +65,6 @@ void SplitTransModuleImpl::addToContext(torch::Tensor & context, const Config &
 void SplitTransModuleImpl::registerEmbeddings()
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, std::set<std::size_t>()));
 }
 
diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp
index f3ac977..b5e81af 100644
--- a/torch_modules/src/StateNameModule.cpp
+++ b/torch_modules/src/StateNameModule.cpp
@@ -38,6 +38,6 @@ void StateNameModuleImpl::addToContext(torch::Tensor & context, const Config & c
 void StateNameModuleImpl::registerEmbeddings()
 {
   if (!embeddings)
-    embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize));
+    embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize, std::set<std::size_t>()));
 }
 
diff --git a/torch_modules/src/WordEmbeddings.cpp b/torch_modules/src/WordEmbeddings.cpp
index c931d6d..38a5e3b 100644
--- a/torch_modules/src/WordEmbeddings.cpp
+++ b/torch_modules/src/WordEmbeddings.cpp
@@ -1,20 +1,32 @@
 #include "WordEmbeddings.hpp"
+#include "util.hpp"
+#include "NeuralNetwork.hpp"
 
 bool WordEmbeddingsImpl::scaleGradByFreq = false;
 bool WordEmbeddingsImpl::canTrainPretrained = false;
 float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
 
-WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
+WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim, std::set<std::size_t> specialIndexes)
 {
+  // forward() assumes that the special values occupy exactly the indexes [0, nbSpecial)
+  for (auto elem : specialIndexes)
+    if (elem >= specialIndexes.size())
+      util::error("Special indexes are not contiguous from zero.");
   if (maxNorm == std::numeric_limits<float>::max())
-    embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq)));
+  {
+    normalEmbeddings = register_module("normalEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq)));
+    specialEmbeddings = register_module("specialEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(specialIndexes.size(), dim).scale_grad_by_freq(scaleGradByFreq)));
+  }
   else
-    embeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq)));
+  {
+    normalEmbeddings = register_module("normalEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).max_norm(maxNorm).scale_grad_by_freq(scaleGradByFreq)));
+    specialEmbeddings = register_module("specialEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(specialIndexes.size(), dim).scale_grad_by_freq(scaleGradByFreq)));
+  }
 }
 
-torch::nn::Embedding WordEmbeddingsImpl::get()
+torch::nn::Embedding WordEmbeddingsImpl::getNormalEmbeddings()
 {
-  return embeddings;
+  return normalEmbeddings;
 }
 
 void WordEmbeddingsImpl::setScaleGradByFreq(bool scaleGradByFreq)
@@ -34,7 +46,23 @@ void WordEmbeddingsImpl::setCanTrainPretrained(bool value)
 
 torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
 {
-  return embeddings(input);
+  // No special embeddings: a single table covers the whole vocabulary
+  if (specialEmbeddings->weight.size(0) == 0)
+    return normalEmbeddings(input);
+
+  // Special values occupy the indexes [0, nbSpecial), so the mask is true for normal indexes
+  auto mask = input >= specialEmbeddings->weight.size(0);
+  // Remap normal indexes to 0 so that the lookup in specialEmbeddings stays in range
+  auto specialIndexes = torch::ones(input.sizes(), torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
+  specialIndexes.index_put_({mask}, 0);
+  auto normalRes = normalEmbeddings(input);
+  auto specialRes = specialEmbeddings(input * specialIndexes);
+  // Blend the two lookups: rows of normalRes for normal indexes, rows of specialRes for special ones
+  auto normalIndexes = torch::ones(normalRes.sizes(), torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
+  specialIndexes = torch::ones(specialRes.sizes(), torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
+  specialIndexes.index_put_({mask}, 0);
+  normalIndexes.index_put_({~mask}, 0);
+  return normalIndexes * normalRes + specialIndexes * specialRes;
 }
 
 bool WordEmbeddingsImpl::getCanTrainPretrained()
-- 
GitLab