From 8023999cd784287a0fc9c208942d0ea0772267aa Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sun, 12 Apr 2020 13:06:12 +0200 Subject: [PATCH] DepthLayerTreeEmbedding no longer contains gov --- torch_modules/src/DepthLayerTreeEmbedding.cpp | 8 +------- torch_modules/src/LSTMNetwork.cpp | 2 +- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/torch_modules/src/DepthLayerTreeEmbedding.cpp b/torch_modules/src/DepthLayerTreeEmbedding.cpp index f215960..6e1342a 100644 --- a/torch_modules/src/DepthLayerTreeEmbedding.cpp +++ b/torch_modules/src/DepthLayerTreeEmbedding.cpp @@ -70,13 +70,7 @@ void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & newChilds.insert(newChilds.end(), val.begin(), val.end()); } childs = newChilds; - if (depth == 0) - { - newChilds.clear(); - auto gov = config.has(0,index,0) ? config.getAsFeature(Config::headColName, index).get() : "-1"; - newChilds.emplace_back(util::isEmpty(gov) ? "-1" : gov); - newChilds.insert(newChilds.end(), childs.begin(), childs.end()); - } + for (int i = 0; i < maxElemPerDepth[depth]; i++) for (auto & col : columns) if (i < (int)newChilds.size() and config.has(col, std::stoi(newChilds[i]), 0)) diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp index 7a25024..bf77626 100644 --- a/torch_modules/src/LSTMNetwork.cpp +++ b/torch_modules/src/LSTMNetwork.cpp @@ -26,7 +26,7 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std:: if (!treeEmbeddingColumns.empty()) { hasTreeEmbedding = true; - treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,lstmOptionsAll)); + treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,lstmOptions)); treeEmbedding->setFirstInputIndex(currentInputSize); currentOutputSize += treeEmbedding->getOutputSize(); currentInputSize += treeEmbedding->getInputSize(); -- GitLab