diff --git a/torch_modules/src/RTLSTMNetwork.cpp b/torch_modules/src/RTLSTMNetwork.cpp index b59892c7770404e64b63ea653b600088b4aa1f33..75ded3a96f15d532566716226d056969441d6287 100644 --- a/torch_modules/src/RTLSTMNetwork.cpp +++ b/torch_modules/src/RTLSTMNetwork.cpp @@ -23,49 +23,59 @@ RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBor torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input) { - input = input.squeeze(); - if (input.dim() != 1) - util::myThrow(fmt::format("Does not support batched input (dim()={})", input.dim())); - - auto focusedIndexes = input.narrow(0, 0, focusedBufferIndexes.size()+focusedStackIndexes.size()); - auto computeOrder = input.narrow(0, focusedIndexes.size(0), leftBorder+rightBorder+1); - auto childsFlat = input.narrow(0, focusedIndexes.size(0)+computeOrder.size(0), maxNbChilds*(leftBorder+rightBorder+1)); - auto childs = torch::reshape(childsFlat, {computeOrder.size(0), maxNbChilds}); - auto wordIndexes = input.narrow(0, focusedIndexes.size(0)+computeOrder.size(0)+childsFlat.size(0), columns.size()*(leftBorder+rightBorder+1)); + if (input.dim() == 1) + input = input.unsqueeze(0); + + auto focusedIndexes = input.narrow(1, 0, focusedBufferIndexes.size()+focusedStackIndexes.size()); + auto computeOrder = input.narrow(1, focusedIndexes.size(1), leftBorder+rightBorder+1); + auto childsFlat = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1), maxNbChilds*(leftBorder+rightBorder+1)); + auto childs = torch::reshape(childsFlat, {childsFlat.size(0), computeOrder.size(1), maxNbChilds}); + auto wordIndexes = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1)+childsFlat.size(1), columns.size()*(leftBorder+rightBorder+1)); auto baseEmbeddings = wordEmbeddings(wordIndexes); - auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {(int)baseEmbeddings.size(0)/(int)columns.size(), (int)baseEmbeddings.size(1)*(int)columns.size()}).unsqueeze(0); - auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output.squeeze(); - std::vector<torch::Tensor> treeRepresentations(vectorRepresentations.size(0), nullTree); - for (unsigned int i = 0; i < computeOrder.size(0); i++) + auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {baseEmbeddings.size(0), (int)baseEmbeddings.size(1)/(int)columns.size(), (int)baseEmbeddings.size(2)*(int)columns.size()}); + auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output; + + std::vector<std::map<int, torch::Tensor>> treeRepresentations; + for (unsigned int batch = 0; batch < computeOrder.size(0); batch++) { - int index = computeOrder[i].item<int>(); - if (index == -1) - break; - std::vector<torch::Tensor> inputVector; - inputVector.emplace_back(torch::cat({vectorRepresentations[index], S}, 0)); - for (unsigned int childIndex = 0; childIndex < maxNbChilds; childIndex++) + treeRepresentations.emplace_back(); + for (unsigned int i = 0; i < computeOrder[batch].size(0); i++) { - int child = childs[index][childIndex].item<int>(); - if (child == -1) + int index = computeOrder[batch][i].item<int>(); + if (index == -1) break; - inputVector.emplace_back(torch::cat({vectorRepresentations[index], treeRepresentations[child]}, 0)); + std::vector<torch::Tensor> inputVector; + inputVector.emplace_back(torch::cat({vectorRepresentations[batch][index], S}, 0)); + for (unsigned int childIndex = 0; childIndex < maxNbChilds; childIndex++) + { + int child = childs[batch][index][childIndex].item<int>(); + if (child == -1) + break; + inputVector.emplace_back(torch::cat({vectorRepresentations[batch][index], treeRepresentations[batch].count(child) ? treeRepresentations[batch][child] : nullTree}, 0)); + } + auto lstmInput = torch::stack(inputVector, 0).unsqueeze(0); + auto lstmOut = treeLSTM(lstmInput).output.permute({1,0,2})[-1].squeeze(); + treeRepresentations[batch][index] = lstmOut; } - auto lstmInput = torch::stack(inputVector, 0).unsqueeze(0); - auto lstmOut = treeLSTM(lstmInput).output.permute({1,0,2})[-1].squeeze(); - treeRepresentations[index] = lstmOut; } std::vector<torch::Tensor> focusedTrees; - for (unsigned int i = 0; i < focusedIndexes.size(0); i++) + std::vector<torch::Tensor> representations; + for (unsigned int batch = 0; batch < focusedIndexes.size(0); batch++) { - int index = focusedIndexes[i].item<int>(); - if (index == -1) - focusedTrees.emplace_back(nullTree); - else - focusedTrees.emplace_back(treeRepresentations[index]); + focusedTrees.clear(); + for (unsigned int i = 0; i < focusedIndexes[batch].size(0); i++) + { + int index = focusedIndexes[batch][i].item<int>(); + if (index == -1) + focusedTrees.emplace_back(nullTree); + else + focusedTrees.emplace_back(treeRepresentations[batch].count(index) ? treeRepresentations[batch][index] : nullTree); + } + representations.emplace_back(torch::cat(focusedTrees, 0).unsqueeze(0)); } - auto representation = torch::cat(focusedTrees, 0); + auto representation = torch::cat(representations, 0); return linear2(torch::relu(linear1(representation))); }