From dd942e181084b2f285981a5195770670033866fa Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 29 Jun 2020 14:08:11 +0200
Subject: [PATCH] Batched indexing for ContextualModule

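Replace the per-batch-element torch::index_select loop in
ContextualModuleImpl::forward with a single gather-based call,
batchedIndexSelect, so the focused positions of every batch element are
selected at once. Feature values are also prefixed with their column name
when no pretrained embedding file (w2vFile) is given.

Minimal standalone sketch of the gather-based selection, outside of this
patch and assuming only libtorch; tensor names, shapes and the main()
harness are illustrative, not part of the module:

    #include <torch/torch.h>
    #include <iostream>
    #include <vector>

    int main()
    {
      auto out = torch::rand({4, 10, 8});                      // (batch, seq, feat)
      auto idx = torch::randint(0, 10, {4, 3}, torch::kLong);  // (batch, k) positions

      // Previous approach: one index_select per batch element, then stack.
      std::vector<torch::Tensor> perElement;
      for (int64_t i = 0; i < out.size(0); i++)
        perElement.emplace_back(torch::index_select(out[i], 0, idx[i]).view({-1}));
      auto reference = torch::stack(perElement);

      // Batched approach: expand the index and gather along dim 1 in one call,
      // as batchedIndexSelect does for dim == 1.
      auto expanded = idx.unsqueeze(2).expand({-1, -1, out.size(2)});
      auto batched = torch::gather(out, 1, expanded).view({out.size(0), -1});

      std::cout << torch::allclose(reference, batched) << std::endl;  // prints 1
      return 0;
    }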
---
 torch_modules/src/ContextualModule.cpp | 35 ++++++++++++++++++++++++++++-------
 1 file changed, 28 insertions(+), 7 deletions(-)

diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index d857a2d..12fbdac 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -158,7 +158,12 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
           dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index)));
         }
         else
-          dictIndex = dict.getIndexOrInsert(functions[colIndex](config.getAsFeature(col, index)));
+        {
+          std::string featureValue = config.getAsFeature(col, index);
+          if (w2vFile.empty())
+            featureValue = fmt::format("{}({})", col, featureValue);
+          dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue));
+        }
 
         for (auto & contextElement : context)
           contextElement.push_back(dictIndex);
@@ -176,6 +181,27 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
   }
 }
 
+// Batched index_select: for each element of the batch, keep the entries of
+// input along dimension dim at the positions given by the matching row of
+// index, using a single gather call instead of one index_select per element.
+torch::Tensor batchedIndexSelect(torch::Tensor input, int dim, torch::Tensor index)
+{
+  // Unsqueeze index until it has as many dimensions as input.
+  for (int i = 1; i < input.dim(); i++)
+    if (i != dim)
+      index = index.unsqueeze(i);
+
+  // Expand index to match input on every dimension except 0 (batch) and dim.
+  std::vector<long> expanse(input.dim());
+  for (unsigned int i = 1; i < expanse.size(); i++)
+    expanse[i] = input.size(i);
+  expanse[0] = -1;
+  expanse[dim] = -1;
+  index = index.expand(expanse);
+
+  return torch::gather(input, dim, index);
+}
+
 torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 {
   auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()-targets.size())).view({input.size(0), (2+window.second-window.first), -1});
@@ -183,12 +209,7 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 
   auto out = myModule->forward(context);
 
-  std::vector<torch::Tensor> batchElems;
-
-  for (unsigned int i = 0; i < input.size(0); i++)
-    batchElems.emplace_back(torch::index_select(out[i], 0, focusedIndexes[i]).view({-1}));
-
-  return torch::stack(batchElems);
+  return batchedIndexSelect(out, 1, focusedIndexes).view({input.size(0), -1});
 }
 
 void ContextualModuleImpl::registerEmbeddings()
-- 
GitLab