From d80db35406f40635a5a304e3041789d834f2142d Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Thu, 5 Mar 2020 13:49:56 +0100 Subject: [PATCH] Fixed CNNNetwork --- torch_modules/src/CNNNetwork.cpp | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 71c79b3..2981b74 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -31,19 +31,19 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) input = input.unsqueeze(0); auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder)); - auto curIndex = wordIndexes.size(1); - std::vector<torch::Tensor> cnnOutputs; + auto elementsEmbeddings = wordEmbeddings(input.narrow(1, wordIndexes.size(1), input.size(1)-wordIndexes.size(1))); + std::vector<torch::Tensor> cnnOutputs; + auto curIndex = 0; for (unsigned int i = 0; i < focusedColumns.size(); i++) { - long nbElements = input[0][curIndex].item<long>(); - - curIndex++; + long nbElements = maxNbElements[i]; for (unsigned int focused = 0; focused < focusedBufferIndexes.size()+focusedStackIndexes.size(); focused++) { - cnnOutputs.emplace_back(cnns[i](wordEmbeddings(input.narrow(1, curIndex, nbElements)).unsqueeze(1))); + auto cnnInput = elementsEmbeddings.narrow(1, curIndex, nbElements).unsqueeze(1); curIndex += nbElements; + cnnOutputs.emplace_back(cnns[i](cnnInput)); } } @@ -60,8 +60,8 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c std::vector<long> contextIndexes = extractContextIndexes(config); std::vector<long> context; - for (auto & col : columns) - for (auto index : contextIndexes) + for (auto index : contextIndexes) + for (auto & col : columns) if (index == -1) context.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); else @@ -71,8 +71,6 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c { auto & col = focusedColumns[colIndex]; - context.push_back(maxNbElements[colIndex]); - std::vector<int> focusedIndexes; for (auto relIndex : focusedBufferIndexes) { -- GitLab