diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
index d857a2d4fb927f9a6e9ad5c8916070ed6955d8b0..12fbdac338d92ddd4ecb56847d42e40f83e9e90c 100644
--- a/torch_modules/src/ContextualModule.cpp
+++ b/torch_modules/src/ContextualModule.cpp
@@ -158,7 +158,12 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
           dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index)));
         }
         else
-          dictIndex = dict.getIndexOrInsert(functions[colIndex](config.getAsFeature(col, index)));
+        {
+          std::string featureValue = config.getAsFeature(col, index);
+          if (w2vFile.empty())
+            featureValue = fmt::format("{}({})", col, featureValue);
+          dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue));
+        }
 
         for (auto & contextElement : context)
           contextElement.push_back(dictIndex);
@@ -176,6 +181,22 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
   }
 }
 
+torch::Tensor batchedIndexSelect(torch::Tensor input, int dim, torch::Tensor index)
+{
+  for (int i = 1; i < input.dim(); i++)
+    if (i != dim)
+      index = index.unsqueeze(i);
+
+  std::vector<long> expanse(input.dim());
+  for (unsigned int i = 1; i < expanse.size(); i++)
+    expanse[i] = input.size(i);
+  expanse[0] = -1;
+  expanse[dim] = -1;
+  index = index.expand(expanse);
+
+  return torch::gather(input, dim, index);
+}
+
 torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 {
   auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()-targets.size())).view({input.size(0), (2+window.second-window.first), -1});
@@ -183,12 +204,7 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 
   auto out = myModule->forward(context);
 
-  std::vector<torch::Tensor> batchElems;
-
-  for (unsigned int i = 0; i < input.size(0); i++)
-    batchElems.emplace_back(torch::index_select(out[i], 0, focusedIndexes[i]).view({-1}));
-
-  return torch::stack(batchElems);
+  return batchedIndexSelect(out, 1, focusedIndexes).view({input.size(0), -1});
 }
 
 void ContextualModuleImpl::registerEmbeddings()