diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp index d857a2d4fb927f9a6e9ad5c8916070ed6955d8b0..12fbdac338d92ddd4ecb56847d42e40f83e9e90c 100644 --- a/torch_modules/src/ContextualModule.cpp +++ b/torch_modules/src/ContextualModule.cpp @@ -158,7 +158,12 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index))); } else - dictIndex = dict.getIndexOrInsert(functions[colIndex](config.getAsFeature(col, index))); + { + std::string featureValue = config.getAsFeature(col, index); + if (w2vFile.empty()) + featureValue = fmt::format("{}({})", col, featureValue); + dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue)); + } for (auto & contextElement : context) contextElement.push_back(dictIndex); @@ -176,6 +181,22 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context } } +torch::Tensor batchedIndexSelect(torch::Tensor input, int dim, torch::Tensor index) +{ + for (int i = 1; i < input.dim(); i++) + if (i != dim) + index = index.unsqueeze(i); + + std::vector<long> expanse(input.dim()); + for (unsigned int i = 1; i < expanse.size(); i++) + expanse[i] = input.size(i); + expanse[0] = -1; + expanse[dim] = -1; + index = index.expand(expanse); + + return torch::gather(input, dim, index); +} + torch::Tensor ContextualModuleImpl::forward(torch::Tensor input) { auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()-targets.size())).view({input.size(0), (2+window.second-window.first), -1}); @@ -183,12 +204,7 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input) auto out = myModule->forward(context); - std::vector<torch::Tensor> batchElems; - - for (unsigned int i = 0; i < input.size(0); i++) - batchElems.emplace_back(torch::index_select(out[i], 0, focusedIndexes[i]).view({-1})); - - return torch::stack(batchElems); + return batchedIndexSelect(out, 1, focusedIndexes).view({input.size(0), -1}); } void ContextualModuleImpl::registerEmbeddings()