Commit dd942e18 authored by Franck Dary

Batched indexing for ContextualModule

parent 59ddb0d2
@@ -158,7 +158,12 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
         dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index)));
       }
       else
-        dictIndex = dict.getIndexOrInsert(functions[colIndex](config.getAsFeature(col, index)));
+      {
+        std::string featureValue = config.getAsFeature(col, index);
+        if (w2vFile.empty())
+          featureValue = fmt::format("{}({})", col, featureValue);
+        dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue));
+      }
       for (auto & contextElement : context)
         contextElement.push_back(dictIndex);
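Note on the hunk above: when no pretrained embedding file is configured (w2vFile is empty), the feature value is now prefixed with its column name before the dictionary lookup, presumably so that identical values coming from different columns map to distinct dictionary entries. A minimal standalone sketch of the resulting key, using a hypothetical column "FORM" and value "dog" (neither appears in the commit):

#include <fmt/core.h>
#include <string>

int main()
{
  // Hypothetical column name and feature value, for illustration only.
  std::string col = "FORM";
  std::string featureValue = "dog";
  // Same formatting call as in the hunk above; the dictionary key becomes "FORM(dog)".
  featureValue = fmt::format("{}({})", col, featureValue);
  return featureValue == "FORM(dog)" ? 0 : 1;
}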
@@ -176,6 +181,22 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
   }
 }
+torch::Tensor batchedIndexSelect(torch::Tensor input, int dim, torch::Tensor index)
+{
+  for (int i = 1; i < input.dim(); i++)
+    if (i != dim)
+      index = index.unsqueeze(i);
+  std::vector<long> expanse(input.dim());
+  for (unsigned int i = 1; i < expanse.size(); i++)
+    expanse[i] = input.size(i);
+  expanse[0] = -1;
+  expanse[dim] = -1;
+  index = index.expand(expanse);
+  return torch::gather(input, dim, index);
+}
 torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 {
   auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()-targets.size())).view({input.size(0), (2+window.second-window.first), -1});
@@ -183,12 +204,7 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
   auto out = myModule->forward(context);
-  std::vector<torch::Tensor> batchElems;
-  for (unsigned int i = 0; i < input.size(0); i++)
-    batchElems.emplace_back(torch::index_select(out[i], 0, focusedIndexes[i]).view({-1}));
-  return torch::stack(batchElems);
+  return batchedIndexSelect(out, 1, focusedIndexes).view({input.size(0), -1});
 }
 void ContextualModuleImpl::registerEmbeddings()
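For context on the forward() change: the old code looped over the batch, running torch::index_select on each element and stacking the results; the new batchedIndexSelect performs the same selection in a single torch::gather over the whole batch. The sketch below (my own check, not part of the commit; the tensor shapes are made up) copies the helper from the diff and verifies it against the old per-element loop:

#include <torch/torch.h>
#include <vector>

// Copy of the helper added in the diff above, so the check below is self-contained.
torch::Tensor batchedIndexSelect(torch::Tensor input, int dim, torch::Tensor index)
{
  for (int i = 1; i < input.dim(); i++)
    if (i != dim)
      index = index.unsqueeze(i);
  std::vector<long> expanse(input.dim());
  for (unsigned int i = 1; i < expanse.size(); i++)
    expanse[i] = input.size(i);
  expanse[0] = -1;
  expanse[dim] = -1;
  index = index.expand(expanse);
  return torch::gather(input, dim, index);
}

int main()
{
  // Hypothetical shapes: batch of 4, 9 context positions, embedding size 8,
  // and 3 focused positions per batch element (none of these numbers come from the commit).
  auto out = torch::randn({4, 9, 8});
  auto focusedIndexes = torch::randint(0, 9, {4, 3}, torch::kLong);

  // Old behaviour: one index_select per batch element, then stack.
  std::vector<torch::Tensor> batchElems;
  for (long i = 0; i < out.size(0); i++)
    batchElems.emplace_back(torch::index_select(out[i], 0, focusedIndexes[i]).view({-1}));
  auto looped = torch::stack(batchElems);

  // New behaviour: a single batched gather over the whole batch.
  auto batched = batchedIndexSelect(out, 1, focusedIndexes).view({out.size(0), -1});

  // The two results should be identical.
  return torch::allclose(looped, batched) ? 0 : 1;
}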