From c94fc04fb34161d40d21564a6c85e7ea5ac45bbc Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 7 Feb 2021 21:11:35 +0100
Subject: [PATCH] Added special dict value for when a feature target the child
 ofa node without one

---
 common/include/Dict.hpp                |  1 +
 common/src/Dict.cpp                    |  2 ++
 torch_modules/src/ContextModule.cpp    | 11 ++++++++---
 torch_modules/src/ContextualModule.cpp |  7 +++----
 4 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index 031d463..64a5539 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -17,6 +17,7 @@ class Dict
 
   static constexpr char const * unknownValueStr = "__unknownValue__";
   static constexpr char const * nullValueStr = "__nullValue__";
+  static constexpr char const * noChildValueStr = "__noChildValue__";
   static constexpr char const * emptyValueStr = "__emptyValue__";
   static constexpr char const * separatorValueStr = "__separatorValue__";
   static constexpr char const * numberValueStr = "__numberValue__";
diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp
index 2187a0c..eab5646 100644
--- a/common/src/Dict.cpp
+++ b/common/src/Dict.cpp
@@ -6,6 +6,7 @@ Dict::Dict(State state)
   setState(state);
   insert(unknownValueStr);
   insert(nullValueStr);
+  insert(noChildValueStr);
   insert(emptyValueStr);
   insert(numberValueStr);
   insert(urlValueStr);
@@ -300,6 +301,7 @@ bool Dict::isSpecialValue(const std::string & value)
 {
   return value == unknownValueStr
   || value == nullValueStr
+  || value == noChildValueStr
   || value == emptyValueStr
   || value == separatorValueStr
   || value == numberValueStr
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index 0671210..cadd496 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -98,19 +98,19 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
       {
         int childIndex = *std::get<2>(target);
         auto childs = util::split(config.getAsFeature(Config::childsColName, baseIndex).get(), '|');
-        int candidate = -1;
+        int candidate = -2;
 
         if (childIndex >= 0 and childIndex < (int)childs.size())
         {
           candidate = std::stoi(childs[childIndex]);
           if (candidate > baseIndex)
-            candidate = -1;
+            candidate = -2;
         }
         else if (childIndex < 0 and ((int)childs.size())+childIndex >= 0)
         {
           candidate = std::stoi(childs[childs.size()+childIndex]);
           if (candidate < baseIndex)
-            candidate = -1;
+            candidate = -2;
         }
 
         contextIndexes.emplace_back(candidate);
@@ -128,6 +128,11 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
         for (auto & contextElement : context)
           contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
       }
+      else if (index == -2)
+      {
+        for (auto & contextElement : context)
+          contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col));
+      }
       else
       {
         int dictIndex;
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index 6338c0c..75aed8b 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -94,7 +94,7 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
   std::vector<long> targetIndexes;
   std::map<long,long> configIndex2ContextIndex;
 
-  contextIndexes.emplace_back(-2);
+  contextIndexes.emplace_back(-1);
 
   for (long i = window.first; i <= window.second; i++)
   {
@@ -117,7 +117,7 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
       {
         int childIndex = *std::get<2>(target);
         auto childs = util::split(config.getAsFeature(Config::childsColName, baseIndex).get(), '|');
-        int candidate = -1;
+        int candidate = -2;
 
         if (childIndex >= 0 and childIndex < (int)childs.size())
           candidate = std::stoi(childs[childIndex]);
@@ -141,9 +141,8 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
       }
       else if (index == -2)
       {
-        //TODO maybe change this to a unique value like Dict::noneValueStr
         for (auto & contextElement : context)
-          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col));
+          contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col));
       }
       else
       {
-- 
GitLab