diff --git a/torch_modules/src/DepthLayerTreeEmbedding.cpp b/torch_modules/src/DepthLayerTreeEmbedding.cpp
index aa4c7aef388105493ce3315d7761acb96b7693ba..f21596038c3eb4c87527d0b47f79a8b272e5b864 100644
--- a/torch_modules/src/DepthLayerTreeEmbedding.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbedding.cpp
@@ -70,10 +70,17 @@ void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> &
             newChilds.insert(newChilds.end(), val.begin(), val.end());
           }
         childs = newChilds;
+        if (depth == 0)
+        {
+          newChilds.clear();
+          auto gov = config.has(0,index,0) ? config.getAsFeature(Config::headColName, index).get() : "-1";
+          newChilds.emplace_back(util::isEmpty(gov) ? "-1" : gov);
+          newChilds.insert(newChilds.end(), childs.begin(), childs.end());
+        }
         for (int i = 0; i < maxElemPerDepth[depth]; i++)
           for (auto & col : columns)
-            if (i < (int)childs.size() and config.has(col, std::stoi(childs[i]), 0))
-              contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(childs[i]))));
+            if (i < (int)newChilds.size() and config.has(col, std::stoi(newChilds[i]), 0))
+              contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i]))));
             else
               contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
       }