From 3ada140da51c757fe82d878b086b87296b84deec Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 5 Feb 2021 12:58:55 +0100
Subject: [PATCH] SplitTransition now ignore case

---
 common/include/util.hpp              |  6 ++++--
 common/src/Dict.cpp                  |  5 ++++-
 common/src/util.cpp                  | 19 +++++++++++++------
 reading_machine/include/Action.hpp   |  2 +-
 reading_machine/include/Config.hpp   |  2 +-
 reading_machine/src/Action.cpp       | 10 ++++++++--
 reading_machine/src/Config.cpp       |  2 +-
 reading_machine/src/Transition.cpp   |  2 +-
 torch_modules/src/RawInputModule.cpp |  4 ++--
 9 files changed, 35 insertions(+), 17 deletions(-)

diff --git a/common/include/util.hpp b/common/include/util.hpp
index 82620a5..1a71486 100644
--- a/common/include/util.hpp
+++ b/common/include/util.hpp
@@ -90,11 +90,13 @@ bool choiceWithProbability(float probability);
 
 std::string lower(const std::string & s);
 
-void lower(utf8string & s);
+void lowerInPlace(utf8string & s);
 
 utf8string lower(const utf8string & s);
 
-void lower(utf8char & c);
+utf8char lower(const utf8char & c);
+
+void lowerInPlace(utf8char & c);
 
 std::string upper(const std::string & s);
 
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index 10cc040..2187a0c 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -9,6 +9,7 @@ Dict::Dict(State state)
   insert(emptyValueStr);
   insert(numberValueStr);
   insert(urlValueStr);
+  insert(separatorValueStr);
 }
 
 Dict::Dict(const char * filename, State state)
@@ -78,8 +79,10 @@ int Dict::getIndexOrInsert(const std::string & element, const std::string & pref
   if (element.empty())
     return getIndexOrInsert(emptyValueStr, prefix);
 
-  if (element.size() == 1 and util::isSeparator(util::utf8char(element)))
+  if (util::printedLength(element) == 1 and util::isSeparator(util::utf8char(element)))
+  {
     return getIndexOrInsert(separatorValueStr, prefix);
+  }
 
   if (util::isNumber(element))
     return getIndexOrInsert(numberValueStr, prefix);
diff --git a/common/src/util.cpp b/common/src/util.cpp
index ec795d8..b084f81 100644
--- a/common/src/util.cpp
+++ b/common/src/util.cpp
@@ -32,7 +32,7 @@ bool util::isSeparator(utf8char c)
 
 bool util::isIllegal(utf8char c)
 {
-  return c == '\n' || c == '\t';
+  return c == '\n' || c == '\t' || c == '\r';
 }
 
 bool util::isNumber(const std::string & s)
@@ -248,12 +248,12 @@ bool util::isUppercase(utf8char c)
 std::string util::lower(const std::string & s)
 {
   auto splited = util::splitAsUtf8(s);
-  lower(splited);
+  lowerInPlace(splited);
 
   return fmt::format("{}", splited);
 }
 
-void util::lower(utf8string & s)
+void util::lowerInPlace(utf8string & s)
 {
   for (auto & c : s)
   {
@@ -265,19 +265,26 @@ void util::lower(utf8string & s)
 
 util::utf8string util::lower(const utf8string & s)
 {
-  auto result = s;
-  lower(result);
+  utf8string result = s;
+  lowerInPlace(result);
 
   return result;
 }
 
-void util::lower(utf8char & c)
+void util::lowerInPlace(utf8char & c)
 {
   auto it = upper2lower.find(c);
   if (it != upper2lower.end())
     c = it->second;
 }
 
+util::utf8char util::lower(const utf8char & c)
+{
+  auto res = c;
+  lowerInPlace(res);
+  return res;
+}
+
 std::string util::upper(const std::string & s)
 {
   auto splited = util::splitAsUtf8(s);
diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp
index e376f50..a0b23a3 100644
--- a/reading_machine/include/Action.hpp
+++ b/reading_machine/include/Action.hpp
@@ -59,7 +59,7 @@ class Action
   static Action attach(Config::Object governorObject, int governorIndex, Config::Object dependentObject, int dependentIndex);
   static Action addCharsToCol(const std::string & col, int n, Config::Object object, int relativeIndex);
   static Action ignoreCurrentCharacter();
-  static Action consumeCharacterIndex(util::utf8string consumed);
+  static Action consumeCharacterIndex(const util::utf8string & consumed);
   static Action setMultiwordIds(int multiwordSize);
   static Action split(int index);
   static Action setRootUpdateIdsEmptyStackIfSentChanged();
diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp
index dbd49e9..42b55af 100644
--- a/reading_machine/include/Config.hpp
+++ b/reading_machine/include/Config.hpp
@@ -120,7 +120,7 @@ class Config
   String & getFirstEmpty(int colIndex, int lineIndex);
   String & getFirstEmpty(const std::string & colName, int lineIndex);
   bool hasCharacter(int letterIndex) const;
-  util::utf8char getLetter(int letterIndex) const;
+  const util::utf8char & getLetter(int letterIndex) const;
   void addToHistory(const std::string & transition);
   void addToStack(std::size_t index);
   void popStack();
diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp
index 5d4dedb..c158808 100644
--- a/reading_machine/src/Action.cpp
+++ b/reading_machine/src/Action.cpp
@@ -76,7 +76,7 @@ Action Action::setMultiwordIds(int multiwordSize)
   return {Type::Write, apply, undo, appliable};
 }
 
-Action Action::consumeCharacterIndex(util::utf8string consumed)
+Action Action::consumeCharacterIndex(const util::utf8string & consumed)
 {
   auto apply = [consumed](Config & config, Action &)
   {
@@ -97,8 +97,14 @@ Action Action::consumeCharacterIndex(util::utf8string consumed)
       return false;
 
     for (unsigned int i = 0; i < consumed.size(); i++)
-      if (!config.hasCharacter(config.getCharacterIndex()+i) or config.getLetter(config.getCharacterIndex()+i) != consumed[i])
+    {
+      if (!config.hasCharacter(config.getCharacterIndex()+i))
+        return false;
+      const util::utf8char & letter = config.getLetter(config.getCharacterIndex()+i);
+      const util::utf8char & consumedLetter = consumed[i];
+      if (util::lower(letter) != util::lower(consumedLetter))
         return false;
+    }
 
     return true;
   };
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 5b37d2c..5bd1f7d 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -384,7 +384,7 @@ bool Config::hasCharacter(int letterIndex) const
   return letterIndex >= 0 and letterIndex < (int)util::getSize(rawInput);
 }
 
-util::utf8char Config::getLetter(int letterIndex) const
+const util::utf8char & Config::getLetter(int letterIndex) const
 {
   return rawInput[letterIndex];
 }
diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp
index 68d9dbb..31dfd5d 100644
--- a/reading_machine/src/Transition.cpp
+++ b/reading_machine/src/Transition.cpp
@@ -328,7 +328,7 @@ void Transition::initSplitWord(std::vector<std::string> words)
       return std::numeric_limits<int>::max();
 
     for (unsigned int i = 0; i < words.size(); i++)
-      if (!config.has("FORM", config.getWordIndex()+i, 0) or config.getConst("FORM", config.getWordIndex()+i, 0) != words[i])
+      if (!config.has("FORM", config.getWordIndex()+i, 0) or util::lower(config.getConst("FORM", config.getWordIndex()+i, 0)) != words[i])
         return std::numeric_limits<int>::max();
 
     return 0;
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index 66bd13d..4ed7a52 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -64,13 +64,13 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context,
   {
     for (int i = 0; i < leftWindow; i++)
       if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i))
-        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, config.getLetter(config.getCharacterIndex()-leftWindow+i)), ""));
+        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i)), prefix));
       else
         contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
 
     for (int i = 0; i <= rightWindow; i++)
       if (config.hasCharacter(config.getCharacterIndex()+i))
-        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, config.getLetter(config.getCharacterIndex()+i)), ""));
+        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)), prefix));
       else
         contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix));
   }
-- 
GitLab