diff --git a/torch_modules/src/DepthLayerTreeEmbedding.cpp b/torch_modules/src/DepthLayerTreeEmbedding.cpp index aa4c7aef388105493ce3315d7761acb96b7693ba..f21596038c3eb4c87527d0b47f79a8b272e5b864 100644 --- a/torch_modules/src/DepthLayerTreeEmbedding.cpp +++ b/torch_modules/src/DepthLayerTreeEmbedding.cpp @@ -70,10 +70,17 @@ void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & newChilds.insert(newChilds.end(), val.begin(), val.end()); } childs = newChilds; + if (depth == 0) + { + newChilds.clear(); + auto gov = config.has(0,index,0) ? config.getAsFeature(Config::headColName, index).get() : "-1"; + newChilds.emplace_back(util::isEmpty(gov) ? "-1" : gov); + newChilds.insert(newChilds.end(), childs.begin(), childs.end()); + } for (int i = 0; i < maxElemPerDepth[depth]; i++) for (auto & col : columns) - if (i < (int)childs.size() and config.has(col, std::stoi(childs[i]), 0)) - contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(childs[i])))); + if (i < (int)newChilds.size() and config.has(col, std::stoi(newChilds[i]), 0)) + contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i])))); else contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); }