From 7cf7a308a175d4ff89dc70291c28a86c816b4e99 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 19 Mar 2020 15:21:44 +0100
Subject: [PATCH] Dict: don't insert separators

---
 common/include/Dict.hpp           | 1 +
 common/include/utf8string.hpp     | 1 +
 common/src/Dict.cpp               | 5 ++++-
 common/src/utf8string.cpp         | 5 +++++
 torch_modules/src/CNNNetwork.cpp  | 6 +++---
 torch_modules/src/LSTMNetwork.cpp | 6 +++---
 6 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index a5ae772..8bc9d3a 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -17,6 +17,7 @@ class Dict
   static constexpr char const * unknownValueStr = "__unknownValue__";
   static constexpr char const * nullValueStr = "__nullValue__";
   static constexpr char const * emptyValueStr = "__emptyValue__";
+  static constexpr char const * separatorValueStr = "__separatorValue__";
   static constexpr std::size_t maxEntrySize = 5000;
 
   private :
diff --git a/common/include/utf8string.hpp b/common/include/utf8string.hpp
index ddd468e..42fe5d4 100644
--- a/common/include/utf8string.hpp
+++ b/common/include/utf8string.hpp
@@ -14,6 +14,7 @@ class utf8char : public std::array<char, 4>
   public :
 
   utf8char();
+  utf8char(const std::string & other);
   utf8char & operator=(char other);
   utf8char & operator=(const std::string & other);
   bool operator==(char other);
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index 6154dc1..b75457e 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -50,7 +50,7 @@ void Dict::readFromFile(const char * filename)
       util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));
 
     elementsToIndexes[entryString] = entryIndex;
-    while (nbOccs.size() <= entryIndex)
+    while ((int)nbOccs.size() <= entryIndex)
       nbOccs.emplace_back(0);
     nbOccs[entryIndex] = nbOccsEntry;
   }
@@ -73,6 +73,9 @@ int Dict::getIndexOrInsert(const std::string & element)
   if (element.empty())
     return getIndexOrInsert(emptyValueStr);
 
+  if (element.size() == 1 and util::isSeparator(util::utf8char(element)))
+    return getIndexOrInsert(separatorValueStr);
+
   const auto & found = elementsToIndexes.find(element);
 
   if (found == elementsToIndexes.end())
diff --git a/common/src/utf8string.cpp b/common/src/utf8string.cpp
index 688b607..d430197 100644
--- a/common/src/utf8string.cpp
+++ b/common/src/utf8string.cpp
@@ -7,6 +7,11 @@ util::utf8char::utf8char()
     val = '\0';
 }
 
+util::utf8char::utf8char(const std::string & other)
+{
+  *this = other;
+}
+
 util::utf8char & util::utf8char::operator=(char other)
 {
   (*this)[0] = other;
diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp
index 96c6b28..68bd974 100644
--- a/torch_modules/src/CNNNetwork.cpp
+++ b/torch_modules/src/CNNNetwork.cpp
@@ -87,13 +87,13 @@ std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, D
   {
     for (int i = 0; i < leftWindowRawInput; i++)
       if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
-        context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
+        context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
       else
         context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
 
     for (int i = 0; i <= rightWindowRawInput; i++)
       if (config.hasCharacter(config.getCharacterIndex()+i))
-        context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i))));
+        context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
       else
         context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
   }
@@ -159,7 +159,7 @@ std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, D
 
           for (int i = 0; i < maxNbElements[colIndex]; i++)
             if (i < (int)asUtf8.size())
-              elements.emplace_back(fmt::format("Letter({})", asUtf8[i]));
+              elements.emplace_back(fmt::format("{}", asUtf8[i]));
             else
               elements.emplace_back(Dict::nullValueStr);
         }
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
index 84c9e79..734770e 100644
--- a/torch_modules/src/LSTMNetwork.cpp
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -105,13 +105,13 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
   {
     for (int i = 0; i < leftWindowRawInput; i++)
       if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
-        context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
+        context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
       else
         context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
 
     for (int i = 0; i <= rightWindowRawInput; i++)
       if (config.hasCharacter(config.getCharacterIndex()+i))
-        context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i))));
+        context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
       else
         context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
   }
@@ -177,7 +177,7 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
 
           for (int i = 0; i < maxNbElements[colIndex]; i++)
             if (i < (int)asUtf8.size())
-              elements.emplace_back(fmt::format("Letter({})", asUtf8[i]));
+              elements.emplace_back(fmt::format("{}", asUtf8[i]));
             else
               elements.emplace_back(Dict::nullValueStr);
         }
-- 
GitLab