diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp index 031d463f417047a3529f17f7b2c15c102e04c50c..64a55399681c0fbc3c7bcce090d2d4be987c8d9e 100644 --- a/common/include/Dict.hpp +++ b/common/include/Dict.hpp @@ -17,6 +17,7 @@ class Dict static constexpr char const * unknownValueStr = "__unknownValue__"; static constexpr char const * nullValueStr = "__nullValue__"; + static constexpr char const * noChildValueStr = "__noChildValue__"; static constexpr char const * emptyValueStr = "__emptyValue__"; static constexpr char const * separatorValueStr = "__separatorValue__"; static constexpr char const * numberValueStr = "__numberValue__"; diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp index 2187a0ca0738b8c11c3ad4b7f81f1eef48f79c6b..eab564683a7afb11900d3b6496b960ae6cb71f71 100644 --- a/common/src/Dict.cpp +++ b/common/src/Dict.cpp @@ -6,6 +6,7 @@ Dict::Dict(State state) setState(state); insert(unknownValueStr); insert(nullValueStr); + insert(noChildValueStr); insert(emptyValueStr); insert(numberValueStr); insert(urlValueStr); @@ -300,6 +301,7 @@ bool Dict::isSpecialValue(const std::string & value) { return value == unknownValueStr || value == nullValueStr + || value == noChildValueStr || value == emptyValueStr || value == separatorValueStr || value == numberValueStr diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index 06712108450a7607b4cab3ddd59cae9407868c4b..cadd4960ecce085800f568a0f64f8c2d321ec7ef 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -98,19 +98,19 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c { int childIndex = *std::get<2>(target); auto childs = util::split(config.getAsFeature(Config::childsColName, baseIndex).get(), '|'); - int candidate = -1; + int candidate = -2; if (childIndex >= 0 and childIndex < (int)childs.size()) { candidate = std::stoi(childs[childIndex]); if (candidate > baseIndex) - candidate = -1; + candidate = -2; } else if (childIndex < 0 and ((int)childs.size())+childIndex >= 0) { candidate = std::stoi(childs[childs.size()+childIndex]); if (candidate < baseIndex) - candidate = -1; + candidate = -2; } contextIndexes.emplace_back(candidate); @@ -128,6 +128,11 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c for (auto & contextElement : context) contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col)); } + else if (index == -2) + { + for (auto & contextElement : context) + contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col)); + } else { int dictIndex; diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp index 6338c0c07403bde095eb1d05466b7f17bb54b189..75aed8bfe95d090976954b224506dbc252b376d5 100644 --- a/torch_modules/src/ContextualModule.cpp +++ b/torch_modules/src/ContextualModule.cpp @@ -94,7 +94,7 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context std::vector<long> targetIndexes; std::map<long,long> configIndex2ContextIndex; - contextIndexes.emplace_back(-2); + contextIndexes.emplace_back(-1); for (long i = window.first; i <= window.second; i++) { @@ -117,7 +117,7 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context { int childIndex = *std::get<2>(target); auto childs = util::split(config.getAsFeature(Config::childsColName, baseIndex).get(), '|'); - int candidate = -1; + int candidate = -2; if (childIndex >= 0 and childIndex < (int)childs.size()) candidate = std::stoi(childs[childIndex]); @@ -141,9 +141,8 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context } else if (index == -2) { - //TODO maybe change this to a unique value like Dict::noneValueStr for (auto & contextElement : context) - contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col)); + contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col)); } else {