Skip to content
Snippets Groups Projects
Commit dd942e18 authored by Franck Dary's avatar Franck Dary
Browse files

Batched indexing for ContextualModule

parent 59ddb0d2
No related branches found
No related tags found
No related merge requests found
......@@ -158,7 +158,12 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index)));
}
else
dictIndex = dict.getIndexOrInsert(functions[colIndex](config.getAsFeature(col, index)));
{
std::string featureValue = config.getAsFeature(col, index);
if (w2vFile.empty())
featureValue = fmt::format("{}({})", col, featureValue);
dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue));
}
for (auto & contextElement : context)
contextElement.push_back(dictIndex);
......@@ -176,6 +181,22 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
}
}
torch::Tensor batchedIndexSelect(torch::Tensor input, int dim, torch::Tensor index)
{
for (int i = 1; i < input.dim(); i++)
if (i != dim)
index = index.unsqueeze(i);
std::vector<long> expanse(input.dim());
for (unsigned int i = 1; i < expanse.size(); i++)
expanse[i] = input.size(i);
expanse[0] = -1;
expanse[dim] = -1;
index = index.expand(expanse);
return torch::gather(input, dim, index);
}
torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
{
auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()-targets.size())).view({input.size(0), (2+window.second-window.first), -1});
......@@ -183,12 +204,7 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
auto out = myModule->forward(context);
std::vector<torch::Tensor> batchElems;
for (unsigned int i = 0; i < input.size(0); i++)
batchElems.emplace_back(torch::index_select(out[i], 0, focusedIndexes[i]).view({-1}));
return torch::stack(batchElems);
return batchedIndexSelect(out, 1, focusedIndexes).view({input.size(0), -1});
}
void ContextualModuleImpl::registerEmbeddings()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment