From 9f7e5b50d2a9fca84ff0e6a234dfbe0071d3a704 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 15 Jun 2020 23:22:53 +0200 Subject: [PATCH] Added special case for EOS and ID in ContextModule --- torch_modules/src/ContextModule.cpp | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index a6e1c85..4b973c0 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -107,7 +107,24 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c } else { - int dictIndex = dict.getIndexOrInsert(functions[colIndex](config.getAsFeature(col, index))); + int dictIndex; + if (col == Config::idColName) + { + std::string value; + if (config.isCommentPredicted(index)) + value = "ID(comment)"; + else if (config.isMultiwordPredicted(index)) + value = "ID(multiword)"; + else if (config.isTokenPredicted(index)) + value = "ID(token)"; + dictIndex = dict.getIndexOrInsert(value); + } + else if (col == Config::EOSColName) + { + dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index))); + } + else + dictIndex = dict.getIndexOrInsert(functions[colIndex](config.getAsFeature(col, index))); for (auto & contextElement : context) contextElement.push_back(dictIndex); -- GitLab