diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 71c79b35cd928fdae4a5acf5d69766074aad55e9..2981b744a4aec17fbe3c996f0559058b903088cc 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -31,19 +31,19 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) input = input.unsqueeze(0); auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder)); - auto curIndex = wordIndexes.size(1); - std::vector<torch::Tensor> cnnOutputs; + auto elementsEmbeddings = wordEmbeddings(input.narrow(1, wordIndexes.size(1), input.size(1)-wordIndexes.size(1))); + std::vector<torch::Tensor> cnnOutputs; + auto curIndex = 0; for (unsigned int i = 0; i < focusedColumns.size(); i++) { - long nbElements = input[0][curIndex].item<long>(); - - curIndex++; + long nbElements = maxNbElements[i]; for (unsigned int focused = 0; focused < focusedBufferIndexes.size()+focusedStackIndexes.size(); focused++) { - cnnOutputs.emplace_back(cnns[i](wordEmbeddings(input.narrow(1, curIndex, nbElements)).unsqueeze(1))); + auto cnnInput = elementsEmbeddings.narrow(1, curIndex, nbElements).unsqueeze(1); curIndex += nbElements; + cnnOutputs.emplace_back(cnns[i](cnnInput)); } } @@ -60,8 +60,8 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c std::vector<long> contextIndexes = extractContextIndexes(config); std::vector<long> context; - for (auto & col : columns) - for (auto index : contextIndexes) + for (auto index : contextIndexes) + for (auto & col : columns) if (index == -1) context.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); else @@ -71,8 +71,6 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c { auto & col = focusedColumns[colIndex]; - context.push_back(maxNbElements[colIndex]); - std::vector<int> focusedIndexes; for (auto relIndex : focusedBufferIndexes) {