From 0ee399631e627ff654863db3027f003790fe95c1 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sat, 11 Apr 2020 12:01:11 +0200 Subject: [PATCH] Added gov to DepthLayerTreeEmbedding --- torch_modules/src/DepthLayerTreeEmbedding.cpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/torch_modules/src/DepthLayerTreeEmbedding.cpp b/torch_modules/src/DepthLayerTreeEmbedding.cpp index aa4c7ae..f215960 100644 --- a/torch_modules/src/DepthLayerTreeEmbedding.cpp +++ b/torch_modules/src/DepthLayerTreeEmbedding.cpp @@ -70,10 +70,17 @@ void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & newChilds.insert(newChilds.end(), val.begin(), val.end()); } childs = newChilds; + if (depth == 0) + { + newChilds.clear(); + auto gov = config.has(0,index,0) ? config.getAsFeature(Config::headColName, index).get() : "-1"; + newChilds.emplace_back(util::isEmpty(gov) ? "-1" : gov); + newChilds.insert(newChilds.end(), childs.begin(), childs.end()); + } for (int i = 0; i < maxElemPerDepth[depth]; i++) for (auto & col : columns) - if (i < (int)childs.size() and config.has(col, std::stoi(childs[i]), 0)) - contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(childs[i])))); + if (i < (int)newChilds.size() and config.has(col, std::stoi(newChilds[i]), 0)) + contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i])))); else contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); } -- GitLab