Commit 29883154 authored by Franck Dary

Made RTLSTMNetwork batched

parent 1da32f54
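For context, after this change RTLSTMNetworkImpl::forward accepts a 2-D input of shape {batchSize, nbFeatures} instead of a single flattened feature vector; a 1-D input is still tolerated and is unsqueezed into a batch of one. A minimal caller-side sketch of the new contract, assuming hypothetical identifiers (network as the module holder, nbFeatures as the flattened feature count, dummy random indexes) that are not part of this commit:

// Hypothetical usage sketch (identifiers are illustrative, not from the commit):
// stack several flattened feature rows into one {batchSize, nbFeatures} tensor
// so that a single forward() call scores the whole batch at once.
int64_t nbFeatures = 100;                                          // assumed flattened feature count
std::vector<torch::Tensor> rows;
for (int i = 0; i < 8; i++)
  rows.emplace_back(torch::randint(0, 1000, {nbFeatures}, torch::kLong)); // dummy index features
auto batch = torch::stack(rows, 0);                                // shape: {8, nbFeatures}
auto scores = network->forward(batch);                             // expected shape: {8, nbOutputs}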
@@ -23,49 +23,59 @@ RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBor
 torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input)
 {
-  input = input.squeeze();
-  if (input.dim() != 1)
-    util::myThrow(fmt::format("Does not support batched input (dim()={})", input.dim()));
-  auto focusedIndexes = input.narrow(0, 0, focusedBufferIndexes.size()+focusedStackIndexes.size());
-  auto computeOrder = input.narrow(0, focusedIndexes.size(0), leftBorder+rightBorder+1);
-  auto childsFlat = input.narrow(0, focusedIndexes.size(0)+computeOrder.size(0), maxNbChilds*(leftBorder+rightBorder+1));
-  auto childs = torch::reshape(childsFlat, {computeOrder.size(0), maxNbChilds});
-  auto wordIndexes = input.narrow(0, focusedIndexes.size(0)+computeOrder.size(0)+childsFlat.size(0), columns.size()*(leftBorder+rightBorder+1));
+  if (input.dim() == 1)
+    input = input.unsqueeze(0);
+  auto focusedIndexes = input.narrow(1, 0, focusedBufferIndexes.size()+focusedStackIndexes.size());
+  auto computeOrder = input.narrow(1, focusedIndexes.size(1), leftBorder+rightBorder+1);
+  auto childsFlat = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1), maxNbChilds*(leftBorder+rightBorder+1));
+  auto childs = torch::reshape(childsFlat, {childsFlat.size(0), computeOrder.size(1), maxNbChilds});
+  auto wordIndexes = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1)+childsFlat.size(1), columns.size()*(leftBorder+rightBorder+1));
   auto baseEmbeddings = wordEmbeddings(wordIndexes);
-  auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {(int)baseEmbeddings.size(0)/(int)columns.size(), (int)baseEmbeddings.size(1)*(int)columns.size()}).unsqueeze(0);
-  auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output.squeeze();
-  std::vector<torch::Tensor> treeRepresentations(vectorRepresentations.size(0), nullTree);
-  for (unsigned int i = 0; i < computeOrder.size(0); i++)
-  {
-    int index = computeOrder[i].item<int>();
-    if (index == -1)
-      break;
-    std::vector<torch::Tensor> inputVector;
-    inputVector.emplace_back(torch::cat({vectorRepresentations[index], S}, 0));
-    for (unsigned int childIndex = 0; childIndex < maxNbChilds; childIndex++)
-    {
-      int child = childs[index][childIndex].item<int>();
-      if (child == -1)
-        break;
-      inputVector.emplace_back(torch::cat({vectorRepresentations[index], treeRepresentations[child]}, 0));
-    }
-    auto lstmInput = torch::stack(inputVector, 0).unsqueeze(0);
-    auto lstmOut = treeLSTM(lstmInput).output.permute({1,0,2})[-1].squeeze();
-    treeRepresentations[index] = lstmOut;
-  }
+  auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {baseEmbeddings.size(0), (int)baseEmbeddings.size(1)/(int)columns.size(), (int)baseEmbeddings.size(2)*(int)columns.size()});
+  auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output;
+  std::vector<std::map<int, torch::Tensor>> treeRepresentations;
+  for (unsigned int batch = 0; batch < computeOrder.size(0); batch++)
+  {
+    treeRepresentations.emplace_back();
+    for (unsigned int i = 0; i < computeOrder[batch].size(0); i++)
+    {
+      int index = computeOrder[batch][i].item<int>();
+      if (index == -1)
+        break;
+      std::vector<torch::Tensor> inputVector;
+      inputVector.emplace_back(torch::cat({vectorRepresentations[batch][index], S}, 0));
+      for (unsigned int childIndex = 0; childIndex < maxNbChilds; childIndex++)
+      {
+        int child = childs[batch][index][childIndex].item<int>();
+        if (child == -1)
+          break;
+        inputVector.emplace_back(torch::cat({vectorRepresentations[batch][index], treeRepresentations[batch].count(child) ? treeRepresentations[batch][child] : nullTree}, 0));
+      }
+      auto lstmInput = torch::stack(inputVector, 0).unsqueeze(0);
+      auto lstmOut = treeLSTM(lstmInput).output.permute({1,0,2})[-1].squeeze();
+      treeRepresentations[batch][index] = lstmOut;
+    }
+  }
   std::vector<torch::Tensor> focusedTrees;
-  for (unsigned int i = 0; i < focusedIndexes.size(0); i++)
-  {
-    int index = focusedIndexes[i].item<int>();
-    if (index == -1)
-      focusedTrees.emplace_back(nullTree);
-    else
-      focusedTrees.emplace_back(treeRepresentations[index]);
-  }
-  auto representation = torch::cat(focusedTrees, 0);
+  std::vector<torch::Tensor> representations;
+  for (unsigned int batch = 0; batch < focusedIndexes.size(0); batch++)
+  {
+    focusedTrees.clear();
+    for (unsigned int i = 0; i < focusedIndexes[batch].size(0); i++)
+    {
+      int index = focusedIndexes[batch][i].item<int>();
+      if (index == -1)
+        focusedTrees.emplace_back(nullTree);
+      else
+        focusedTrees.emplace_back(treeRepresentations[batch].count(index) ? treeRepresentations[batch][index] : nullTree);
+    }
+    representations.emplace_back(torch::cat(focusedTrees, 0).unsqueeze(0));
+  }
+  auto representation = torch::cat(representations, 0);
   return linear2(torch::relu(linear1(representation)));
 }
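A design note on the batched tree storage: treeRepresentations changes from a flat std::vector<torch::Tensor> preallocated with nullTree to one std::map<int, torch::Tensor> per batch element, so positions whose subtree was never computed fall back to nullTree at lookup time (the .count(...) ? ... : nullTree pattern above). A small sketch of how that repeated lookup could be factored inside forward(), assuming a helper name (lookupTree) that is purely illustrative and not part of the commit:

// Illustrative helper (not in the commit): fetch the tree representation for one
// batch element, falling back to nullTree when that index was never filled in.
auto lookupTree = [&](const std::map<int, torch::Tensor> & trees, int index)
{
  auto it = trees.find(index);
  return it == trees.end() ? nullTree : it->second;
};
// Example use, mirroring the child lookup above:
// inputVector.emplace_back(torch::cat({vectorRepresentations[batch][index],
//                                      lookupTree(treeRepresentations[batch], child)}, 0));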