From 57db2a2e15f62c7e0e7b627313ce99fa0dcab4df Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 31 Jul 2020 17:24:12 +0200
Subject: [PATCH] Changed the way prefixes are handled in dicts

---
 common/include/Dict.hpp                    |  7 +-
 common/src/Dict.cpp                        | 65 ++++++++++++++-----
 torch_modules/include/ContextualModule.hpp |  2 +-
 torch_modules/include/Submodule.hpp        |  2 +-
 torch_modules/src/ContextModule.cpp        | 34 +++++-----
 torch_modules/src/ContextualModule.cpp     | 51 ++++++++-------
 .../src/DepthLayerTreeEmbeddingModule.cpp  |  4 +-
 torch_modules/src/DistanceModule.cpp       |  8 ++-
 torch_modules/src/FocusedColumnModule.cpp  | 15 +++--
 torch_modules/src/HistoryModule.cpp        |  6 +-
 torch_modules/src/RawInputModule.cpp       | 10 +--
 torch_modules/src/SplitTransModule.cpp     |  4 +-
 torch_modules/src/StateNameModule.cpp      |  2 +-
 torch_modules/src/Submodule.cpp            | 12 ++--
 14 files changed, 137 insertions(+), 85 deletions(-)

diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index efd5806..dda547b 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -26,6 +26,7 @@ class Dict
   private :
 
   std::unordered_map<std::string, int> elementsToIndexes;
+  std::unordered_map<int, std::string> indexesToElements;
   std::vector<int> nbOccs;
   State state;
   bool isCountingOccs{false};
@@ -43,7 +44,8 @@ class Dict
   public :
 
   void countOcc(bool isCountingOccs);
-  int getIndexOrInsert(const std::string & element);
+  int getIndexOrInsert(const std::string & element, const std::string & prefix);
+  std::string getElement(std::size_t index);
   void setState(State state);
   State getState() const;
   void save(std::filesystem::path path, Encoding encoding) const;
@@ -52,7 +54,8 @@ class Dict
   std::size_t size() const;
   int getNbOccs(int index) const;
   void removeRareElements();
-  void loadWord2Vec(std::filesystem::path path);
+  void loadWord2Vec(std::filesystem::path path, std::string prefix);
+  bool isSpecialValue(const std::string & value);
 };
 
 #endif
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index 49e678f..d4e7ba2 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -42,6 +42,7 @@ void Dict::readFromFile(const char * filename)
     util::myThrow(fmt::format("file '{}' bad format", filename));
 
   elementsToIndexes.reserve(nbEntries);
+  indexesToElements.reserve(nbEntries);
 
   int entryIndex;
   int nbOccsEntry;
@@ -52,6 +53,7 @@
       util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));
 
     elementsToIndexes[entryString] = entryIndex;
+    indexesToElements[entryIndex] = entryString;
     while ((int)nbOccs.size() <= entryIndex)
       nbOccs.emplace_back(0);
     nbOccs[entryIndex] = nbOccsEntry;
@@ -66,37 +68,40 @@ void Dict::insert(const std::string & element)
     util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", element.size(), maxEntrySize));
 
   elementsToIndexes.emplace(element, elementsToIndexes.size());
+  indexesToElements.emplace(elementsToIndexes.size()-1, element);
   while (nbOccs.size() < elementsToIndexes.size())
     nbOccs.emplace_back(0);
 }
 
 
-int Dict::getIndexOrInsert(const std::string & element)
+int Dict::getIndexOrInsert(const std::string & element, const std::string & prefix)
 {
   if (element.empty())
-    return getIndexOrInsert(emptyValueStr);
+    return getIndexOrInsert(emptyValueStr, prefix);
 
   if (element.size() == 1 and util::isSeparator(util::utf8char(element)))
-    return getIndexOrInsert(separatorValueStr);
+    return getIndexOrInsert(separatorValueStr, prefix);
 
   if (util::isNumber(element))
-    return getIndexOrInsert(numberValueStr);
+    return getIndexOrInsert(numberValueStr, prefix);
 
   if (util::isUrl(element))
-    return getIndexOrInsert(urlValueStr);
+    return getIndexOrInsert(urlValueStr, prefix);
 
-  const auto & found = elementsToIndexes.find(element);
+  auto prefixed = prefix.empty() ? element : fmt::format("{}({})", prefix, element);
+  const auto & found = elementsToIndexes.find(prefixed);
   if (found == elementsToIndexes.end())
   {
     if (state == State::Open)
    {
-      insert(element);
+      insert(prefixed);
       if (isCountingOccs)
-        nbOccs[elementsToIndexes[element]]++;
-      return elementsToIndexes[element];
+        nbOccs[elementsToIndexes[prefixed]]++;
+      return elementsToIndexes[prefixed];
     }
 
-    const auto & found2 = elementsToIndexes.find(util::lower(element));
+    prefixed = prefix.empty() ? util::lower(element) : fmt::format("{}({})", prefix, util::lower(element));
+    const auto & found2 = elementsToIndexes.find(prefixed);
     if (found2 != elementsToIndexes.end())
     {
       if (isCountingOccs)
@@ -104,9 +109,10 @@ int Dict::getIndexOrInsert(const std::string & element)
       return found2->second;
     }
 
+    prefixed = prefix.empty() ? unknownValueStr : fmt::format("{}({})", prefix, unknownValueStr);
     if (isCountingOccs)
-      nbOccs[elementsToIndexes[unknownValueStr]]++;
-    return elementsToIndexes[unknownValueStr];
+      nbOccs[elementsToIndexes[prefixed]]++;
+    return elementsToIndexes[prefixed];
   }
 
   if (isCountingOccs)
@@ -217,7 +223,7 @@ void Dict::removeRareElements()
   nbOccs = newNbOccs;
 }
 
-void Dict::loadWord2Vec(std::filesystem::path path)
+void Dict::loadWord2Vec(std::filesystem::path path, std::string prefix)
 {
   if (path.empty())
     return;
@@ -235,6 +241,16 @@
 
   try
   {
+    if (!prefix.empty())
+    {
+      std::vector<std::string> toAdd;
+      for (auto & it : elementsToIndexes)
+        if (isSpecialValue(it.first))
+          toAdd.emplace_back(fmt::format("{}({})", prefix, it.first));
+      for (auto & elem : toAdd)
+        getIndexOrInsert(elem, "");
+    }
+
     while (!std::feof(file))
     {
       if (buffer != std::fgets(buffer, 100000, file))
@@ -251,9 +267,13 @@
       if (splited.size() < 2)
         util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));
 
-      auto dictIndex = getIndexOrInsert(splited[0]);
+      if (splited[0] == "<unk>")
+        continue;
+      auto toInsert = util::splitAsUtf8(splited[0]);
+      toInsert.replace("◌", " ");
+      auto dictIndex = getIndexOrInsert(fmt::format("{}", toInsert), prefix);
 
-      if (dictIndex == getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getIndexOrInsert(Dict::nullValueStr) or dictIndex == getIndexOrInsert(Dict::emptyValueStr))
+      if (dictIndex == getIndexOrInsert(Dict::unknownValueStr, prefix) or dictIndex == getIndexOrInsert(Dict::nullValueStr, prefix) or dictIndex == getIndexOrInsert(Dict::emptyValueStr, prefix))
         util::myThrow(fmt::format("w2v line '{}' gave unexpected special dict index", buffer));
     }
   } catch (std::exception & e)
@@ -269,3 +289,18 @@
   setState(originalState);
 }
 
+bool Dict::isSpecialValue(const std::string & value)
+{
+  return value == unknownValueStr
+    || value == nullValueStr
+    || value == emptyValueStr
+    || value == separatorValueStr
+    || value == numberValueStr
+    || value == urlValueStr;
+}
+
+std::string Dict::getElement(std::size_t index)
+{
+  return indexesToElements[index];
+}
+
diff --git a/torch_modules/include/ContextualModule.hpp b/torch_modules/include/ContextualModule.hpp
index d7e290c..0395c11 100644
--- a/torch_modules/include/ContextualModule.hpp
+++ b/torch_modules/include/ContextualModule.hpp
@@ -22,7 +22,7 @@ class ContextualModuleImpl : public Submodule
   int inSize;
   int outSize;
   std::filesystem::path path;
-  std::filesystem::path w2vFile;
+  std::filesystem::path w2vFiles;
 
   public :
 
diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp
index 70250e0..77c0346 100644
--- a/torch_modules/include/Submodule.hpp
+++ b/torch_modules/include/Submodule.hpp
@@ -16,7 +16,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde
   public :
 
   void setFirstInputIndex(std::size_t firstInputIndex);
-  void loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path);
+  void loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path, std::string prefix);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
   virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index 66a6728..c83de18 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -54,9 +54,14 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
     {
       auto pathes = util::split(w2vFiles.string(), ' ');
       for (auto & p : pathes)
-        getDict().loadWord2Vec(this->path / p);
-      getDict().setState(Dict::State::Closed);
-      dictSetPretrained(true);
+      {
+        auto splited = util::split(p, ',');
+        if (splited.size() != 2)
+          util::myThrow("expected 'prefix,pretrained.w2v'");
+        getDict().loadWord2Vec(this->path / splited[1], splited[0]);
+        getDict().setState(Dict::State::Closed);
+        dictSetPretrained(true);
+      }
     }
   }
   catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
@@ -117,7 +122,7 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
       if (index == -1)
      {
         for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", col, Dict::nullValueStr)));
+          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
       }
       else
       {
@@ -126,23 +131,17 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
         {
           std::string value;
           if (config.isCommentPredicted(index))
-            value = "ID(comment)";
+            value = "comment";
           else if (config.isMultiwordPredicted(index))
-            value = "ID(multiword)";
+            value = "multiword";
           else if (config.isTokenPredicted(index))
-            value = "ID(token)";
-          dictIndex = dict.getIndexOrInsert(value);
-        }
-        else if (col == Config::EOSColName)
-        {
-          dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index)));
+            value = "token";
+          dictIndex = dict.getIndexOrInsert(value, col);
         }
         else
         {
           std::string featureValue = functions[colIndex](config.getAsFeature(col, index));
-          if (w2vFiles.empty())
-            featureValue = fmt::format("{}({})", col, featureValue);
-          dictIndex = dict.getIndexOrInsert(featureValue);
+          dictIndex = dict.getIndexOrInsert(featureValue, col);
         }
 
         for (auto & contextElement : context)
@@ -165,6 +164,9 @@ void ContextModuleImpl::registerEmbeddings()
   wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / p);
+  {
+    auto splited = util::split(p, ',');
+    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+  }
 }
 
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index ebe386a..cc06903 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -53,13 +53,20 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string &
     else
       util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
 
-    w2vFile = sm.str(7);
+    w2vFiles = sm.str(7);
 
-    if (!w2vFile.empty())
+    if (!w2vFiles.empty())
     {
-      getDict().loadWord2Vec(this->path / w2vFile);
-      getDict().setState(Dict::State::Closed);
-      dictSetPretrained(true);
+      auto pathes = util::split(w2vFiles.string(), ' ');
+      for (auto & p : pathes)
+      {
+        auto splited = util::split(p, ',');
+        if (splited.size() != 2)
+          util::myThrow("expected 'prefix,file.w2v'");
+        getDict().loadWord2Vec(this->path / splited[1], splited[0]);
+        getDict().setState(Dict::State::Closed);
+        dictSetPretrained(true);
+      }
     }
   }
   catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
@@ -127,17 +134,13 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
       if (index == -1)
       {
         for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", col, Dict::nullValueStr)));
+          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
       }
       else if (index == -2)
       {
+        //TODO maybe change this to a unique value like Dict::noneValueStr
         for (auto & contextElement : context)
-        {
-          auto currentState = dict.getState();
-          dict.setState(Dict::State::Open);
-          contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", col, "_NONE_")));
-          dict.setState(currentState);
-        }
+          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
       }
       else
       {
@@ -146,23 +149,17 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
         {
           std::string value;
           if (config.isCommentPredicted(index))
-            value = "ID(comment)";
+            value = "comment";
           else if (config.isMultiwordPredicted(index))
-            value = "ID(multiword)";
+            value = "multiword";
           else if (config.isTokenPredicted(index))
-            value = "ID(token)";
-          dictIndex = dict.getIndexOrInsert(value);
-        }
-        else if (col == Config::EOSColName)
-        {
-          dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index)));
+            value = "token";
+          dictIndex = dict.getIndexOrInsert(value, col);
         }
         else
         {
           std::string featureValue = config.getAsFeature(col, index);
-          if (w2vFile.empty())
-            featureValue = fmt::format("{}({})", col, featureValue);
-          dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue));
+          dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue), col);
         }
 
         for (auto & contextElement : context)
@@ -214,6 +211,12 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 void ContextualModuleImpl::registerEmbeddings()
 {
   wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
-  loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile.empty() ? "" : path / w2vFile);
+
+  auto pathes = util::split(w2vFiles.string(), ' ');
+  for (auto & p : pathes)
+  {
+    auto splited = util::split(p, ',');
+    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+  }
 }
 
diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
index 2cb88dc..6d97fbe 100644
--- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
@@ -117,9 +117,9 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon
       for (int i = 0; i < maxElemPerDepth[depth]; i++)
         for (auto & col : columns)
           if (i < (int)newChilds.size() and config.has(col, std::stoi(newChilds[i]), 0))
-            contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i]))));
+            contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i])), col));
           else
-            contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+            contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
     }
   }
 }
diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp
index 40098bc..daf7a3c 100644
--- a/torch_modules/src/DistanceModule.cpp
+++ b/torch_modules/src/DistanceModule.cpp
@@ -86,6 +86,8 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context,
     else
       toIndexes.emplace_back(-1);
 
+  std::string prefix = "DISTANCE";
+
   for (auto & contextElement : context)
   {
     for (auto from : fromIndexes)
@@ -93,16 +95,16 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context,
       {
         if (from == -1 or to == -1)
        {
-          contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+          contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
           continue;
         }
 
         long dist = std::abs(config.getRelativeDistance(from, to));
 
         if (dist <= threshold)
-          contextElement.emplace_back(dict.getIndexOrInsert(fmt::format("distance({})", dist)));
+          contextElement.emplace_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, dist), ""));
         else
-          contextElement.emplace_back(dict.getIndexOrInsert(Dict::unknownValueStr));
+          contextElement.emplace_back(dict.getIndexOrInsert(Dict::unknownValueStr, prefix));
       }
     }
   }
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 62c1aef..556fdc4 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -84,7 +84,7 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
       if (index == -1)
       {
         for (int i = 0; i < maxNbElements; i++)
-          contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+          contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, column));
         continue;
       }
 
@@ -93,6 +93,7 @@ void FocusedColumnModuleImpl::addToCont
      {
         auto asUtf8 = util::splitAsUtf8(func(config.getAsFeature(column, index).get()));
 
+        //TODO don't use nullValueStr here
         for (int i = 0; i < maxNbElements; i++)
           if (i < (int)asUtf8.size())
             elements.emplace_back(fmt::format("{}", asUtf8[i]));
@@ -105,23 +106,23 @@
 
         for (int i = 0; i < maxNbElements; i++)
           if (i < (int)splited.size())
-            elements.emplace_back(fmt::format("FEATS({})", splited[i]));
+            elements.emplace_back(splited[i]);
           else
             elements.emplace_back(Dict::nullValueStr);
       }
       else if (column == "ID")
       {
         if (config.isTokenPredicted(index))
-          elements.emplace_back("ID(TOKEN)");
+          elements.emplace_back("TOKEN");
         else if (config.isMultiwordPredicted(index))
-          elements.emplace_back("ID(MULTIWORD)");
+          elements.emplace_back("MULTIWORD");
         else if (config.isEmptyNodePredicted(index))
-          elements.emplace_back("ID(EMPTYNODE)");
+          elements.emplace_back("EMPTYNODE");
       }
       else if (column == "EOS")
       {
         bool isEOS = func(config.getAsFeature(Config::EOSColName, index)) == Config::EOSSymbol1;
-        elements.emplace_back(fmt::format("EOS({})", isEOS));
+        elements.emplace_back(fmt::format("{}", isEOS));
       }
       else
       {
@@ -132,7 +133,7 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
         util::myThrow(fmt::format("elements.size ({}) != maxNbElements ({})", elements.size(), maxNbElements));
 
       for (auto & element : elements)
-        contextElement.emplace_back(dict.getIndexOrInsert(element));
+        contextElement.emplace_back(dict.getIndexOrInsert(element, column));
     }
   }
 }
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
index eb5c28c..7249116 100644
--- a/torch_modules/src/HistoryModule.cpp
+++ b/torch_modules/src/HistoryModule.cpp
@@ -57,12 +57,14 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
 {
   auto & dict = getDict();
 
+  std::string prefix = "HISTORY";
+
   for (auto & contextElement : context)
     for (int i = 0; i < maxNbElements; i++)
       if (config.hasHistory(i))
-        contextElement.emplace_back(dict.getIndexOrInsert(config.getHistory(i)));
+        contextElement.emplace_back(dict.getIndexOrInsert(config.getHistory(i), prefix));
       else
-        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
 }
 
 void HistoryModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index 8f43a2f..d6adb74 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -57,20 +57,22 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context,
   if (leftWindow < 0 or rightWindow < 0)
     return;
 
+  std::string prefix = "LETTER";
+
   auto & dict = getDict();
   for (auto & contextElement : context)
   {
     for (int i = 0; i < leftWindow; i++)
       if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i))
-        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i))));
+        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, config.getLetter(config.getCharacterIndex()-leftWindow+i)), ""));
       else
-        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
 
     for (int i = 0; i <= rightWindow; i++)
       if (config.hasCharacter(config.getCharacterIndex()+i))
-        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
+        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, config.getLetter(config.getCharacterIndex()+i)), ""));
       else
-        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
   }
 }
 
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index d4f6d84..43964c6 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -58,9 +58,9 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context
   for (auto & contextElement : context)
     for (int i = 0; i < maxNbTrans; i++)
       if (i < (int)splitTransitions.size())
-        contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
+        contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName(), ""));
       else
-        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, ""));
 }
 
 void SplitTransModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp
index 42edd50..18627db 100644
--- a/torch_modules/src/StateNameModule.cpp
+++ b/torch_modules/src/StateNameModule.cpp
@@ -33,7 +33,7 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context,
 {
   auto & dict = getDict();
   for (auto & contextElement : context)
-    contextElement.emplace_back(dict.getIndexOrInsert(config.getState()));
+    contextElement.emplace_back(dict.getIndexOrInsert(config.getState(), ""));
 }
 
 void StateNameModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp
index e52ef5e..589bc96 100644
--- a/torch_modules/src/Submodule.cpp
+++ b/torch_modules/src/Submodule.cpp
@@ -5,7 +5,7 @@ void Submodule::setFirstInputIndex(std::size_t firstInputIndex)
   this->firstInputIndex = firstInputIndex;
 }
 
-void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path)
+void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path, std::string prefix)
 {
   if (path.empty())
     return;
@@ -44,12 +44,14 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s
       if (splited.size() < 2)
         util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));
 
-      auto dictIndex = getDict().getIndexOrInsert(splited[0]);
+      std::string word;
+
       if (splited[0] == "<unk>")
-        dictIndex = getDict().getIndexOrInsert(Dict::unknownValueStr);
+        word = Dict::unknownValueStr;
+      else
+        word = splited[0];
 
-      if (splited[0] != "<unk>" and splited[0] != Dict::unknownValueStr and (dictIndex == getDict().getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::nullValueStr) or dictIndex == getDict().getIndexOrInsert(Dict::emptyValueStr)))
-        continue;
+      auto dictIndex = getDict().getIndexOrInsert(word, prefix);
 
       if (embeddingsSize != splited.size()-1)
         util::myThrow(fmt::format("in line \n{}embeddingsSize='{}' mismatch pretrainedEmbeddingSize='{}'", buffer, embeddingsSize, ((int)splited.size())-1));
-- 
GitLab
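
Note on the new pretrained-embeddings syntax (the file names below are made
up; only the 'prefix,file' shape and the error message come from this patch):
where a module definition previously referenced a single w2v file, the field
now holds space-separated pairs, each pair a dict prefix and a w2v file:

    FORM,pretrained.w2v LEMMA,lemmas.w2v

For each pair the module calls getDict().loadWord2Vec(path / file, prefix),
so every vector read from the file is registered under a key of the form
FORM(word); a pair that does not split into exactly two comma-separated
fields raises "expected 'prefix,pretrained.w2v'". The reader also skips a
literal <unk> token and maps "◌" back to a space, which suggests the
pretraining pipeline escapes spaces inside tokens with that character.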
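
The core behavioral change sits in Dict::getIndexOrInsert: keys are now
stored as "prefix(element)" and the whole fallback chain operates on the
prefixed key. Below is a standalone sketch of that chain, not the project's
code: identifiers are illustrative, lowercasing is ASCII-only (the real code
uses util::lower), "<unk>" stands in for Dict::unknownValueStr, and the real
method first reroutes empty/separator/number/URL elements to special values.

    #include <algorithm>
    #include <cassert>
    #include <cctype>
    #include <string>
    #include <unordered_map>

    // Build the stored key: "prefix(element)", or the raw element if no prefix.
    static std::string makeKey(const std::string & prefix, const std::string & element)
    {
      return prefix.empty() ? element : prefix + "(" + element + ")";
    }

    static std::string lowered(std::string s)
    {
      std::transform(s.begin(), s.end(), s.begin(),
                     [](unsigned char c){ return (char)std::tolower(c); });
      return s;
    }

    static int getIndexOrInsertSketch(std::unordered_map<std::string, int> & dict,
                                      bool isOpen, const std::string & element,
                                      const std::string & prefix)
    {
      // 1) exact match on the prefixed key
      auto found = dict.find(makeKey(prefix, element));
      if (found != dict.end())
        return found->second;
      // 2) open dict: insert the prefixed key as a fresh index
      if (isOpen)
        return dict.emplace(makeKey(prefix, element), (int)dict.size()).first->second;
      // 3) closed dict: retry with the lowercased element
      found = dict.find(makeKey(prefix, lowered(element)));
      if (found != dict.end())
        return found->second;
      // 4) last resort: the prefixed unknown entry, which must pre-exist
      return dict.at(makeKey(prefix, "<unk>"));
    }

    int main()
    {
      std::unordered_map<std::string, int> d{{"FORM(<unk>)", 0}, {"FORM(chat)", 1}};
      assert(getIndexOrInsertSketch(d, false, "chat", "FORM") == 1);  // exact hit
      assert(getIndexOrInsertSketch(d, false, "Chat", "FORM") == 1);  // lowercase retry
      assert(getIndexOrInsertSketch(d, false, "chien", "FORM") == 0); // unknown fallback
      assert(getIndexOrInsertSketch(d, true, "chien", "FORM") == 2);  // open: inserted
      return 0;
    }

Step 4 is presumably why loadWord2Vec pre-registers prefixed copies of the
special values before reading a file: "PREFIX(<unk>)" must already exist in a
closed dict for the fallback to resolve.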
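
A minimal usage sketch of the patched Dict API (construction and state
handling elided; key spellings follow the "prefix(element)" scheme):

    // given some Dict 'dict' in Open state:
    int i = dict.getIndexOrInsert("chat", "FORM");   // stored as "FORM(chat)"
    int j = dict.getIndexOrInsert("chat", "LEMMA");  // separate entry "LEMMA(chat)"
    // i != j : one index per (prefix, element) pair, so the same surface
    // form no longer collides across columns
    // dict.getElement(i) == "FORM(chat)"            // new reverse lookup

Since indexesToElements is filled both by readFromFile and by insert,
getElement works for entries loaded from a saved dict as well as for entries
inserted during training.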