From 61e77a5a6428c8018ffb61eff98b2139e419f8c6 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Thu, 1 Apr 2021 15:50:50 +0200 Subject: [PATCH] Changed the encoding of features in certain modules --- torch_modules/src/ContextModule.cpp | 11 +++++++++- torch_modules/src/ContextualModule.cpp | 12 +++++++++-- torch_modules/src/FocusedColumnModule.cpp | 26 +++++++++++------------ 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index 8289c13..ffea1d0 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -148,10 +148,19 @@ void ContextModuleImpl::addToContext(torch::Tensor & context, const Config & con if (col == Config::idColName) { std::string value; - if (config.isMultiwordPredicted(index)) + if (config.getAsFeature(Config::idColName, index).empty()) + value = "empty"; + else if (config.isMultiwordPredicted(index)) value = "multiword"; + else if (config.getAsFeature(Config::isMultiColName, index) == Config::EOSSymbol1) + value = "part"; else if (config.isTokenPredicted(index)) value = "token"; + else + { + config.printForDebug(stderr); + util::myThrow(fmt::format("{} col at index {} not token nor multiword", Config::idColName, index)); + } dictIndex = dict.getIndexOrInsert(value, col); } else diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp index c0a717c..d435648 100644 --- a/torch_modules/src/ContextualModule.cpp +++ b/torch_modules/src/ContextualModule.cpp @@ -156,11 +156,19 @@ void ContextualModuleImpl::addToContext(torch::Tensor & context, const Config & if (col == Config::idColName) { std::string value; - if (config.isMultiwordPredicted(index)) + if (config.getAsFeature(Config::idColName, index).empty()) + value = "empty"; + else if (config.isMultiwordPredicted(index)) value = "multiword"; + else if (config.getAsFeature(Config::isMultiColName, index) == Config::EOSSymbol1) + value = "part"; else if (config.isTokenPredicted(index)) value = "token"; - dictIndex = dict.getIndexOrInsert(value, col); + else + { + config.printForDebug(stderr); + util::myThrow(fmt::format("{} col at index {} not token nor multiword", Config::idColName, index)); + } } else { diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index a7af65a..77f8c26 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -94,16 +94,16 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config if (config.hasStack(index)) focusedIndexes.emplace_back(config.getStack(index)); else - focusedIndexes.emplace_back(-1); + focusedIndexes.emplace_back(-2); int insertIndex = 0; for (auto index : focusedIndexes) { - if (index == -1) + if (index == -1 or index == -2) { for (int i = 0; i < maxNbElements; i++) { - context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, column); + context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(index == -1 ? Dict::oobValueStr : Dict::nullValueStr, column); insertIndex++; } continue; @@ -113,13 +113,11 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config if (column == "FORM") { auto asUtf8 = util::splitAsUtf8(func(std::string(config.getAsFeature(column, index)))); - - //TODO don't use nullValueStr here for (int i = 0; i < maxNbElements; i++) if (i < (int)asUtf8.size()) elements.emplace_back(fmt::format("{}", asUtf8[i])); else - elements.emplace_back(Dict::nullValueStr); + elements.emplace_back("<padding>"); } else if (column == "FEATS") { @@ -129,16 +127,18 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config if (i < (int)splited.size()) elements.emplace_back(splited[i]); else - elements.emplace_back(Dict::nullValueStr); + elements.emplace_back("<padding>"); } - else if (column == "ID") + else if (column == Config::idColName) { - if (config.isTokenPredicted(index)) - elements.emplace_back("TOKEN"); + if (config.getAsFeature(Config::idColName, index).empty()) + elements.emplace_back("empty"); else if (config.isMultiwordPredicted(index)) - elements.emplace_back("MULTIWORD"); - else if (config.isEmptyNodePredicted(index)) - elements.emplace_back("EMPTYNODE"); + elements.emplace_back("multiword"); + else if (config.getAsFeature(Config::isMultiColName, index) == Config::EOSSymbol1) + elements.emplace_back("part"); + else if (config.isTokenPredicted(index)) + elements.emplace_back("token"); } else if (column == "EOS") { -- GitLab