diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp index 12fbdac338d92ddd4ecb56847d42e40f83e9e90c..ebe386a723f69159b34206ad9787727987448aa0 100644 --- a/torch_modules/src/ContextualModule.cpp +++ b/torch_modules/src/ContextualModule.cpp @@ -173,11 +173,15 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context for (auto index : targetIndexes) { if (configIndex2ContextIndex.count(index)) + { for (auto & contextElement : context) - contextElement.push_back(configIndex2ContextIndex.at(index)+1); + contextElement.push_back(configIndex2ContextIndex.at(index)); + } else + { for (auto & contextElement : context) contextElement.push_back(0); + } } }