Commit 29883154 authored by Franck Dary

Made RTLSTMNetwork batched

parent 1da32f54
@@ -23,49 +23,59 @@ RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBor
 torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input)
 {
   input = input.squeeze();
-  if (input.dim() != 1)
-    util::myThrow(fmt::format("Does not support batched input (dim()={})", input.dim()));
-  auto focusedIndexes = input.narrow(0, 0, focusedBufferIndexes.size()+focusedStackIndexes.size());
-  auto computeOrder = input.narrow(0, focusedIndexes.size(0), leftBorder+rightBorder+1);
-  auto childsFlat = input.narrow(0, focusedIndexes.size(0)+computeOrder.size(0), maxNbChilds*(leftBorder+rightBorder+1));
-  auto childs = torch::reshape(childsFlat, {computeOrder.size(0), maxNbChilds});
-  auto wordIndexes = input.narrow(0, focusedIndexes.size(0)+computeOrder.size(0)+childsFlat.size(0), columns.size()*(leftBorder+rightBorder+1));
+  if (input.dim() == 1)
+    input = input.unsqueeze(0);
+  auto focusedIndexes = input.narrow(1, 0, focusedBufferIndexes.size()+focusedStackIndexes.size());
+  auto computeOrder = input.narrow(1, focusedIndexes.size(1), leftBorder+rightBorder+1);
+  auto childsFlat = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1), maxNbChilds*(leftBorder+rightBorder+1));
+  auto childs = torch::reshape(childsFlat, {childsFlat.size(0), computeOrder.size(1), maxNbChilds});
+  auto wordIndexes = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1)+childsFlat.size(1), columns.size()*(leftBorder+rightBorder+1));
   auto baseEmbeddings = wordEmbeddings(wordIndexes);
-  auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {(int)baseEmbeddings.size(0)/(int)columns.size(), (int)baseEmbeddings.size(1)*(int)columns.size()}).unsqueeze(0);
-  auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output.squeeze();
-  std::vector<torch::Tensor> treeRepresentations(vectorRepresentations.size(0), nullTree);
-  for (unsigned int i = 0; i < computeOrder.size(0); i++)
+  auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {baseEmbeddings.size(0), (int)baseEmbeddings.size(1)/(int)columns.size(), (int)baseEmbeddings.size(2)*(int)columns.size()});
+  auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output;
+  std::vector<std::map<int, torch::Tensor>> treeRepresentations;
+  for (unsigned int batch = 0; batch < computeOrder.size(0); batch++)
   {
-    int index = computeOrder[i].item<int>();
-    if (index == -1)
-      break;
-    std::vector<torch::Tensor> inputVector;
-    inputVector.emplace_back(torch::cat({vectorRepresentations[index], S}, 0));
-    for (unsigned int childIndex = 0; childIndex < maxNbChilds; childIndex++)
+    treeRepresentations.emplace_back();
+    for (unsigned int i = 0; i < computeOrder[batch].size(0); i++)
     {
-      int child = childs[index][childIndex].item<int>();
-      if (child == -1)
+      int index = computeOrder[batch][i].item<int>();
+      if (index == -1)
         break;
-      inputVector.emplace_back(torch::cat({vectorRepresentations[index], treeRepresentations[child]}, 0));
+      std::vector<torch::Tensor> inputVector;
+      inputVector.emplace_back(torch::cat({vectorRepresentations[batch][index], S}, 0));
+      for (unsigned int childIndex = 0; childIndex < maxNbChilds; childIndex++)
+      {
+        int child = childs[batch][index][childIndex].item<int>();
+        if (child == -1)
+          break;
+        inputVector.emplace_back(torch::cat({vectorRepresentations[batch][index], treeRepresentations[batch].count(child) ? treeRepresentations[batch][child] : nullTree}, 0));
+      }
+      auto lstmInput = torch::stack(inputVector, 0).unsqueeze(0);
+      auto lstmOut = treeLSTM(lstmInput).output.permute({1,0,2})[-1].squeeze();
+      treeRepresentations[batch][index] = lstmOut;
     }
-    auto lstmInput = torch::stack(inputVector, 0).unsqueeze(0);
-    auto lstmOut = treeLSTM(lstmInput).output.permute({1,0,2})[-1].squeeze();
-    treeRepresentations[index] = lstmOut;
   }
   std::vector<torch::Tensor> focusedTrees;
-  for (unsigned int i = 0; i < focusedIndexes.size(0); i++)
+  std::vector<torch::Tensor> representations;
+  for (unsigned int batch = 0; batch < focusedIndexes.size(0); batch++)
   {
-    int index = focusedIndexes[i].item<int>();
-    if (index == -1)
-      focusedTrees.emplace_back(nullTree);
-    else
-      focusedTrees.emplace_back(treeRepresentations[index]);
+    focusedTrees.clear();
+    for (unsigned int i = 0; i < focusedIndexes[batch].size(0); i++)
+    {
+      int index = focusedIndexes[batch][i].item<int>();
+      if (index == -1)
+        focusedTrees.emplace_back(nullTree);
+      else
+        focusedTrees.emplace_back(treeRepresentations[batch].count(index) ? treeRepresentations[batch][index] : nullTree);
+    }
+    representations.emplace_back(torch::cat(focusedTrees, 0).unsqueeze(0));
  }
-  auto representation = torch::cat(focusedTrees, 0);
+  auto representation = torch::cat(representations, 0);
   return linear2(torch::relu(linear1(representation)));
 }
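
For context (not part of this commit): a minimal sketch of how a caller could now feed a batch to the updated forward(). The layout constants below (nbFocused, window, nbColumns and the local maxNbChilds) are illustrative placeholders standing in for focusedBufferIndexes.size()+focusedStackIndexes.size(), leftBorder+rightBorder+1, columns.size() and the network's maxNbChilds; only the stacking along dim 0 and the narrow() offsets along dim 1 mirror the code above.

#include <torch/torch.h>
#include <iostream>

int main()
{
  // Illustrative sizes (assumptions, not values from the repository).
  int64_t nbFocused = 4;   // stands in for focusedBufferIndexes.size()+focusedStackIndexes.size()
  int64_t window = 3;      // stands in for leftBorder+rightBorder+1
  int64_t maxNbChilds = 2;
  int64_t nbColumns = 2;   // stands in for columns.size()
  int64_t featuresPerExample = nbFocused + window + maxNbChilds*window + nbColumns*window;

  // Two flattened feature vectors stacked along dim 0 -> shape {2, featuresPerExample}.
  auto a = torch::zeros({featuresPerExample}, torch::kLong);
  auto b = torch::zeros({featuresPerExample}, torch::kLong);
  auto batch = torch::stack({a, b}, 0);

  // The batched forward() narrows along dim 1, so each slice keeps the batch
  // dimension and selects one feature group for every example at once.
  auto focusedIndexes = batch.narrow(1, 0, nbFocused);
  auto computeOrder = batch.narrow(1, nbFocused, window);
  auto childsFlat = batch.narrow(1, nbFocused+window, maxNbChilds*window);
  auto wordIndexes = batch.narrow(1, nbFocused+window+maxNbChilds*window, nbColumns*window);

  std::cout << focusedIndexes.sizes() << " " << computeOrder.sizes() << " "
            << childsFlat.sizes() << " " << wordIndexes.sizes() << std::endl;
  return 0;
}

Note on the change itself: the per-example std::map<int, torch::Tensor> with a nullTree fallback replaces the old pre-filled std::vector, so indexes that never appear in a given example's computeOrder still resolve to a well-defined representation.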