diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index 364f2cbf9e7bde85b9b007b9da4303ffe6533956..2e13833591fbac8fe99125e122d093d6b1611971 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -85,12 +85,22 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c { int childIndex = *std::get<2>(target); auto childs = util::split(config.getAsFeature(Config::childsColName, baseIndex).get(), '|'); + int candidate = -1; + if (childIndex >= 0 and childIndex < (int)childs.size()) - contextIndexes.emplace_back(std::stoi(childs[childIndex])); + { + candidate = std::stoi(childs[childIndex]); + if (candidate > baseIndex) + candidate = -1; + } else if (childIndex < 0 and ((int)childs.size())+childIndex >= 0) - contextIndexes.emplace_back(std::stoi(childs[childs.size()+childIndex])); - else - contextIndexes.emplace_back(-1); + { + candidate = std::stoi(childs[childs.size()+childIndex]); + if (candidate < baseIndex) + candidate = -1; + } + + contextIndexes.emplace_back(candidate); } } else