From 8023999cd784287a0fc9c208942d0ea0772267aa Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 12 Apr 2020 13:06:12 +0200
Subject: [PATCH] DepthLayerTreeEmbedding no longer contains gov

---
 torch_modules/src/DepthLayerTreeEmbedding.cpp | 8 +-------
 torch_modules/src/LSTMNetwork.cpp             | 2 +-
 2 files changed, 2 insertions(+), 8 deletions(-)

diff --git a/torch_modules/src/DepthLayerTreeEmbedding.cpp b/torch_modules/src/DepthLayerTreeEmbedding.cpp
index f215960..6e1342a 100644
--- a/torch_modules/src/DepthLayerTreeEmbedding.cpp
+++ b/torch_modules/src/DepthLayerTreeEmbedding.cpp
@@ -70,13 +70,7 @@ void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> &
             newChilds.insert(newChilds.end(), val.begin(), val.end());
           }
         childs = newChilds;
-        if (depth == 0)
-        {
-          newChilds.clear();
-          auto gov = config.has(0,index,0) ? config.getAsFeature(Config::headColName, index).get() : "-1";
-          newChilds.emplace_back(util::isEmpty(gov) ? "-1" : gov);
-          newChilds.insert(newChilds.end(), childs.begin(), childs.end());
-        }
+
         for (int i = 0; i < maxElemPerDepth[depth]; i++)
           for (auto & col : columns)
             if (i < (int)newChilds.size() and config.has(col, std::stoi(newChilds[i]), 0))
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
index 7a25024..bf77626 100644
--- a/torch_modules/src/LSTMNetwork.cpp
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -26,7 +26,7 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
   if (!treeEmbeddingColumns.empty())
   {
     hasTreeEmbedding = true;
-    treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,lstmOptionsAll));
+    treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,lstmOptions));
     treeEmbedding->setFirstInputIndex(currentInputSize);
     currentOutputSize += treeEmbedding->getOutputSize();
     currentInputSize += treeEmbedding->getInputSize();
-- 
GitLab