diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index 8289c13c3a87040be35ddf3b1098a092a7e37b83..ffea1d0c056252c1389b31fe024fc7de6c1c731b 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -148,10 +148,19 @@ void ContextModuleImpl::addToContext(torch::Tensor & context, const Config & con if (col == Config::idColName) { std::string value; - if (config.isMultiwordPredicted(index)) + if (config.getAsFeature(Config::idColName, index).empty()) + value = "empty"; + else if (config.isMultiwordPredicted(index)) value = "multiword"; + else if (config.getAsFeature(Config::isMultiColName, index) == Config::EOSSymbol1) + value = "part"; else if (config.isTokenPredicted(index)) value = "token"; + else + { + config.printForDebug(stderr); + util::myThrow(fmt::format("{} col at index {} not token nor multiword", Config::idColName, index)); + } dictIndex = dict.getIndexOrInsert(value, col); } else diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp index c0a717c14ed245a84ea908ccc95685c5c6260a41..d4356484dae267d30442acdc965e1d639dfd9bff 100644 --- a/torch_modules/src/ContextualModule.cpp +++ b/torch_modules/src/ContextualModule.cpp @@ -156,11 +156,19 @@ void ContextualModuleImpl::addToContext(torch::Tensor & context, const Config & if (col == Config::idColName) { std::string value; - if (config.isMultiwordPredicted(index)) + if (config.getAsFeature(Config::idColName, index).empty()) + value = "empty"; + else if (config.isMultiwordPredicted(index)) value = "multiword"; + else if (config.getAsFeature(Config::isMultiColName, index) == Config::EOSSymbol1) + value = "part"; else if (config.isTokenPredicted(index)) value = "token"; - dictIndex = dict.getIndexOrInsert(value, col); + else + { + config.printForDebug(stderr); + util::myThrow(fmt::format("{} col at index {} not token nor multiword", Config::idColName, index)); + } } else { diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index a7af65a275a422d867e5fedb64e47fbd19d6ab87..77f8c2647be31e64da1db25d045b8378f6e95564 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -94,16 +94,16 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config if (config.hasStack(index)) focusedIndexes.emplace_back(config.getStack(index)); else - focusedIndexes.emplace_back(-1); + focusedIndexes.emplace_back(-2); int insertIndex = 0; for (auto index : focusedIndexes) { - if (index == -1) + if (index == -1 or index == -2) { for (int i = 0; i < maxNbElements; i++) { - context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, column); + context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(index == -1 ? Dict::oobValueStr : Dict::nullValueStr, column); insertIndex++; } continue; @@ -113,13 +113,11 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config if (column == "FORM") { auto asUtf8 = util::splitAsUtf8(func(std::string(config.getAsFeature(column, index)))); - - //TODO don't use nullValueStr here for (int i = 0; i < maxNbElements; i++) if (i < (int)asUtf8.size()) elements.emplace_back(fmt::format("{}", asUtf8[i])); else - elements.emplace_back(Dict::nullValueStr); + elements.emplace_back("<padding>"); } else if (column == "FEATS") { @@ -129,16 +127,18 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config if (i < (int)splited.size()) elements.emplace_back(splited[i]); else - elements.emplace_back(Dict::nullValueStr); + elements.emplace_back("<padding>"); } - else if (column == "ID") + else if (column == Config::idColName) { - if (config.isTokenPredicted(index)) - elements.emplace_back("TOKEN"); + if (config.getAsFeature(Config::idColName, index).empty()) + elements.emplace_back("empty"); else if (config.isMultiwordPredicted(index)) - elements.emplace_back("MULTIWORD"); - else if (config.isEmptyNodePredicted(index)) - elements.emplace_back("EMPTYNODE"); + elements.emplace_back("multiword"); + else if (config.getAsFeature(Config::isMultiColName, index) == Config::EOSSymbol1) + elements.emplace_back("part"); + else if (config.isTokenPredicted(index)) + elements.emplace_back("token"); } else if (column == "EOS") {