Skip to content
Snippets Groups Projects
Commit 92d8266b authored by Franck Dary's avatar Franck Dary
Browse files

Added out of bound special value

parent f4f14edf
No related branches found
No related tags found
No related merge requests found
......@@ -116,8 +116,14 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
contextIndexes.emplace_back(candidate);
}
}
else
else if (std::get<0>(target) == Config::Object::Stack)
{
contextIndexes.emplace_back(-1);
}
else
{
contextIndexes.emplace_back(-3);
}
for (auto index : contextIndexes)
for (unsigned int colIndex = 0; colIndex < columns.size(); colIndex++)
......@@ -133,6 +139,11 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col));
}
else if (index == -3)
{
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::oobValueStr, col));
}
else
{
int dictIndex;
......
......@@ -84,7 +84,7 @@ std::size_t ContextualModuleImpl::getOutputSize()
std::size_t ContextualModuleImpl::getInputSize()
{
return columns.size()*(2+window.second-window.first)+targets.size();
return columns.size()*(4+window.second-window.first)+targets.size();
}
void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
......@@ -95,6 +95,8 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
std::map<long,long> configIndex2ContextIndex;
contextIndexes.emplace_back(-1);
contextIndexes.emplace_back(-2);
contextIndexes.emplace_back(-3);
for (long i = window.first; i <= window.second; i++)
{
......@@ -104,7 +106,7 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
configIndex2ContextIndex[contextIndexes.back()] = contextIndexes.size()-1;
}
else
contextIndexes.emplace_back(-1);
contextIndexes.emplace_back(-3);
}
for (auto & target : targets)
......@@ -144,6 +146,11 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col));
}
else if (index == -3)
{
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::oobValueStr, col));
}
else
{
int dictIndex;
......@@ -177,7 +184,17 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
else
{
for (auto & contextElement : context)
contextElement.push_back(0);
{
// -1 == doesn't exist (s.0 when no stack)
if (index == -1)
contextElement.push_back(0);
// -2 == nochild
else if (index == -2)
contextElement.push_back(1);
// other == out of context bounds
else
contextElement.push_back(2);
}
}
}
}
......@@ -200,7 +217,7 @@ torch::Tensor batchedIndexSelect(torch::Tensor input, int dim, torch::Tensor ind
torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
{
auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()-targets.size())).view({input.size(0), (2+window.second-window.first), -1});
auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()-targets.size())).view({input.size(0), (4+window.second-window.first), -1});
auto focusedIndexes = input.narrow(1, firstInputIndex+getInputSize()-targets.size(), targets.size());
auto out = myModule->forward(context);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment