Skip to content
Snippets Groups Projects
Commit dd942e18 authored by Franck Dary's avatar Franck Dary
Browse files

Batched indexing for ContextualModule

parent 59ddb0d2
No related branches found
No related tags found
No related merge requests found
...@@ -158,7 +158,12 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context ...@@ -158,7 +158,12 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index))); dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index)));
} }
else else
dictIndex = dict.getIndexOrInsert(functions[colIndex](config.getAsFeature(col, index))); {
std::string featureValue = config.getAsFeature(col, index);
if (w2vFile.empty())
featureValue = fmt::format("{}({})", col, featureValue);
dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue));
}
for (auto & contextElement : context) for (auto & contextElement : context)
contextElement.push_back(dictIndex); contextElement.push_back(dictIndex);
...@@ -176,6 +181,22 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context ...@@ -176,6 +181,22 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
} }
} }
// Batched equivalent of torch::index_select: for each element of the batch
// (dimension 0), picks the rows of `input` along dimension `dim` given by the
// corresponding row of `index`, in a single gather call instead of a per-batch
// loop.
//   input : tensor of rank >= 2, batch first.
//   dim   : dimension (> 0) to select along.
//   index : long tensor of shape (batch, k) — one set of positions per batch.
// Returns a tensor shaped like `input` with dimension `dim` replaced by k.
torch::Tensor batchedIndexSelect(torch::Tensor input, int dim, torch::Tensor index)
{
  const int nbDims = input.dim();
  // Give `index` the same rank as `input` by inserting singleton dimensions
  // everywhere except the batch dimension (0) and the selection dimension.
  for (int curDim = 1; curDim < nbDims; ++curDim)
    if (curDim != dim)
      index = index.unsqueeze(curDim);
  // Broadcast the singleton dimensions up to input's sizes; -1 keeps the
  // existing size at the batch and selection dimensions.
  std::vector<long> broadcastShape(nbDims, -1);
  for (int curDim = 1; curDim < nbDims; ++curDim)
    if (curDim != dim)
      broadcastShape[curDim] = input.size(curDim);
  // gather reads input[b, ..., index[b, ..., j], ...] along `dim`.
  return torch::gather(input, dim, index.expand(broadcastShape));
}
torch::Tensor ContextualModuleImpl::forward(torch::Tensor input) torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
{ {
auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()-targets.size())).view({input.size(0), (2+window.second-window.first), -1}); auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()-targets.size())).view({input.size(0), (2+window.second-window.first), -1});
...@@ -183,12 +204,7 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input) ...@@ -183,12 +204,7 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
auto out = myModule->forward(context); auto out = myModule->forward(context);
std::vector<torch::Tensor> batchElems; return batchedIndexSelect(out, 1, focusedIndexes).view({input.size(0), -1});
for (unsigned int i = 0; i < input.size(0); i++)
batchElems.emplace_back(torch::index_select(out[i], 0, focusedIndexes[i]).view({-1}));
return torch::stack(batchElems);
} }
void ContextualModuleImpl::registerEmbeddings() void ContextualModuleImpl::registerEmbeddings()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment