diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index cadd4960ecce085800f568a0f64f8c2d321ec7ef..f457f314b3595bf0c51d438893c3743609605a98 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -116,8 +116,14 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c contextIndexes.emplace_back(candidate); } } - else + else if (std::get<0>(target) == Config::Object::Stack) + { contextIndexes.emplace_back(-1); + } + else + { + contextIndexes.emplace_back(-3); + } for (auto index : contextIndexes) for (unsigned int colIndex = 0; colIndex < columns.size(); colIndex++) @@ -133,6 +139,11 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c for (auto & contextElement : context) contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col)); } + else if (index == -3) + { + for (auto & contextElement : context) + contextElement.push_back(dict.getIndexOrInsert(Dict::oobValueStr, col)); + } else { int dictIndex; diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp index 75aed8bfe95d090976954b224506dbc252b376d5..f5f8562e26b811af74966ec8343eeaacb6774e2f 100644 --- a/torch_modules/src/ContextualModule.cpp +++ b/torch_modules/src/ContextualModule.cpp @@ -84,7 +84,7 @@ std::size_t ContextualModuleImpl::getOutputSize() std::size_t ContextualModuleImpl::getInputSize() { - return columns.size()*(2+window.second-window.first)+targets.size(); + return columns.size()*(4+window.second-window.first)+targets.size(); } void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) @@ -95,6 +95,8 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context std::map<long,long> configIndex2ContextIndex; contextIndexes.emplace_back(-1); + contextIndexes.emplace_back(-2); + contextIndexes.emplace_back(-3); for (long i = window.first; i <= window.second; i++) { @@ -104,7 +106,7 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context configIndex2ContextIndex[contextIndexes.back()] = contextIndexes.size()-1; } else - contextIndexes.emplace_back(-1); + contextIndexes.emplace_back(-3); } for (auto & target : targets) @@ -144,6 +146,11 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context for (auto & contextElement : context) contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col)); } + else if (index == -3) + { + for (auto & contextElement : context) + contextElement.push_back(dict.getIndexOrInsert(Dict::oobValueStr, col)); + } else { int dictIndex; @@ -177,7 +184,17 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context else { for (auto & contextElement : context) - contextElement.push_back(0); + { + // -1 == doesn't exist (s.0 when no stack) + if (index == -1) + contextElement.push_back(0); + // -2 == nochild + else if (index == -2) + contextElement.push_back(1); + // other == out of context bounds + else + contextElement.push_back(2); + } } } } @@ -200,7 +217,7 @@ torch::Tensor batchedIndexSelect(torch::Tensor input, int dim, torch::Tensor ind torch::Tensor ContextualModuleImpl::forward(torch::Tensor input) { - auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()-targets.size())).view({input.size(0), (2+window.second-window.first), -1}); + auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()-targets.size())).view({input.size(0), (4+window.second-window.first), -1}); auto focusedIndexes = input.narrow(1, firstInputIndex+getInputSize()-targets.size(), targets.size()); auto out = myModule->forward(context);