From 9f7e5b50d2a9fca84ff0e6a234dfbe0071d3a704 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 15 Jun 2020 23:22:53 +0200
Subject: [PATCH] Added special case for EOS and ID in ContextModule

---
 torch_modules/src/ContextModule.cpp | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index a6e1c85..4b973c0 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -107,7 +107,24 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
       }
       else
       {
-        int dictIndex = dict.getIndexOrInsert(functions[colIndex](config.getAsFeature(col, index)));
+        int dictIndex;
+        if (col == Config::idColName)
+        {
+          std::string value;
+          if (config.isCommentPredicted(index))
+            value = "ID(comment)";
+          else if (config.isMultiwordPredicted(index))
+            value = "ID(multiword)";
+          else if (config.isTokenPredicted(index))
+            value = "ID(token)";
+          dictIndex = dict.getIndexOrInsert(value);
+        }
+        else if (col == Config::EOSColName)
+        {
+          dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index)));
+        }
+        else
+          dictIndex = dict.getIndexOrInsert(functions[colIndex](config.getAsFeature(col, index)));
 
         for (auto & contextElement : context)
           contextElement.push_back(dictIndex);
-- 
GitLab