From dfd75ada672f65f3717bd54b76330a3a011e1303 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 12 Nov 2021 20:56:32 +0100
Subject: [PATCH] Having separate wordEmbeddings for special values

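Until now the special values (unknown, null, oob, noChild, empty, number,
url and separator) were inserted only once, without a prefix, when a Dict
was constructed. This factors that list into Dict::addPrefixValues(prefix):
the Dict(State) constructor calls it with an empty prefix, and
_getIndexOrInsert calls it the first time a prefix is seen, so every prefix
gets its own "prefix(specialValue)" entries. readFromFile also recovers the
prefix of each stored entry (the part before '(') so that known prefixes
survive a save/load cycle, and prefix registration moves out of
getIndexOrInsert.

Illustration only (placeholder names; the real strings are the *ValueStr
constants defined by Dict):

    dict.getIndexOrInsert("dog", "FORM");
    // the first use of prefix "FORM" also registers FORM(<unknownValue>),
    // FORM(<nullValue>), ... as that prefix's special entries

On the model side, ContextModule, ContextualModule and FocusedColumnModule
now always hand getDict().getSpecialIndexes() to WordEmbeddings, not only
when pretrained w2v files are loaded, and WordEmbeddingsImpl checks that
those indexes form a contiguous range starting at zero, since it keeps a
separate embedding table for them.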
---
 common/include/Dict.hpp                   |  1 +
 common/src/Dict.cpp                       | 31 ++++++++++++++++-------
 torch_modules/src/ContextModule.cpp       |  2 +-
 torch_modules/src/ContextualModule.cpp    |  2 +-
 torch_modules/src/FocusedColumnModule.cpp |  2 +-
 torch_modules/src/WordEmbeddings.cpp      |  3 +++
 6 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index fa33fef..d37ff65 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -45,6 +45,7 @@ class Dict
 
   private :
 
+  void addPrefixValues(std::string prefix);
   void readFromFile(const char * filename);
   void insert(const std::string & element);
   void reset();
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index 2a4f118..82b0c8d 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -5,14 +5,7 @@ Dict::Dict(State state)
 {
   locked = false;
   setState(state);
-  insert(unknownValueStr);
-  insert(nullValueStr);
-  insert(oobValueStr);
-  insert(noChildValueStr);
-  insert(emptyValueStr);
-  insert(numberValueStr);
-  insert(urlValueStr);
-  insert(separatorValueStr);
+  addPrefixValues("");
 }
 
 Dict::Dict(const char * filename, State state)
@@ -22,6 +15,17 @@ Dict::Dict(const char * filename, State state)
   locked = false;
 }
 
+void Dict::addPrefixValues(std::string prefix)
+{
+  for (auto & element : {unknownValueStr, nullValueStr, oobValueStr, noChildValueStr, emptyValueStr, numberValueStr, urlValueStr, separatorValueStr})
+  {
+    std::string prefixed = prefix.empty() ? element : fmt::format("{}({})", prefix, element);
+    // avoid inserting a value that is already in the dict
+    if (!elementsToIndexes.count(prefixed))
+      insert(prefixed);
+  }
+}
+
 void Dict::lock()
 {
   locked = true;
@@ -64,6 +68,11 @@ void Dict::readFromFile(const char * filename)
     if (!readEntry(file, &entryIndex, &nbOccsEntry, entryString, encoding))
       util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));
 
+    std::string prefix = "";
+    auto splited = util::split(entryString, '(');
+    if (splited.size() > 1)
+      prefix = splited[0];
+    prefixes.insert(prefix);
     if (elementsToIndexes.count(entryString))
       util::myThrow(fmt::format("entry '{}' is already in dict", entryString));
     if (indexesToElements.count(entryIndex))
@@ -101,7 +110,6 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & pref
   if (state == State::Open)
     elementsMutex.lock();
 
-  prefixes.insert(prefix);
   int index = _getIndexOrInsert(element, prefix);
 
   if (state == State::Open)
@@ -112,6 +120,11 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & pref
 
 int Dict::_getIndexOrInsert(const std::string & element, const std::string & prefix)
 {
+  if (!prefixes.count(prefix))
+  {
+    prefixes.insert(prefix);
+    addPrefixValues(prefix);
+  }
   if (element.empty())
     return _getIndexOrInsert(emptyValueStr, prefix);
 
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index 48f9a00..ae92243 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -187,7 +187,7 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 void ContextModuleImpl::registerEmbeddings(bool loadPretrained)
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index 564c95f..3c81258 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -234,7 +234,7 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 void ContextualModuleImpl::registerEmbeddings(bool loadPretrained)
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, getDict().getSpecialIndexes()));
 
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 23ebe6f..1a0a9d3 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -164,7 +164,7 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config
 void FocusedColumnModuleImpl::registerEmbeddings(bool loadPretrained)
 {
   if (!wordEmbeddings)
-    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, w2vFiles.empty() ? std::set<std::size_t>() : getDict().getSpecialIndexes()));
+    wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize, getDict().getSpecialIndexes()));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
diff --git a/torch_modules/src/WordEmbeddings.cpp b/torch_modules/src/WordEmbeddings.cpp
index 38a5e3b..045296f 100644
--- a/torch_modules/src/WordEmbeddings.cpp
+++ b/torch_modules/src/WordEmbeddings.cpp
@@ -9,8 +9,10 @@ float WordEmbeddingsImpl::maxNorm = std::numeric_limits<float>::max();
 WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim, std::set<std::size_t> specialIndexes)
 {
   for (auto elem : specialIndexes)
+  {
     if (elem >= specialIndexes.size())
       util::error("Special indexes are not contiguous from zero.");
+  }
   if (maxNorm == std::numeric_limits<float>::max())
   {
     normalEmbeddings = register_module("normalEmbeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq)));
@@ -57,6 +59,7 @@ torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
   specialIndexes = torch::ones(specialRes.sizes(),torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
   specialIndexes.index_put_({mask}, 0);
   normalIndexes.index_put_({~mask}, 0);
+
   return normalIndexes*normalRes + specialIndexes*specialRes;
 }
 
-- 
GitLab