Skip to content
Snippets Groups Projects
Commit 92d8266b authored by Franck Dary's avatar Franck Dary
Browse files

Added out of bound special value

parent f4f14edf
Branches
No related tags found
No related merge requests found
...@@ -116,8 +116,14 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c ...@@ -116,8 +116,14 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
contextIndexes.emplace_back(candidate); contextIndexes.emplace_back(candidate);
} }
} }
else else if (std::get<0>(target) == Config::Object::Stack)
{
contextIndexes.emplace_back(-1); contextIndexes.emplace_back(-1);
}
else
{
contextIndexes.emplace_back(-3);
}
for (auto index : contextIndexes) for (auto index : contextIndexes)
for (unsigned int colIndex = 0; colIndex < columns.size(); colIndex++) for (unsigned int colIndex = 0; colIndex < columns.size(); colIndex++)
...@@ -133,6 +139,11 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c ...@@ -133,6 +139,11 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c
for (auto & contextElement : context) for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col)); contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col));
} }
else if (index == -3)
{
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::oobValueStr, col));
}
else else
{ {
int dictIndex; int dictIndex;
......
...@@ -84,7 +84,7 @@ std::size_t ContextualModuleImpl::getOutputSize() ...@@ -84,7 +84,7 @@ std::size_t ContextualModuleImpl::getOutputSize()
std::size_t ContextualModuleImpl::getInputSize() std::size_t ContextualModuleImpl::getInputSize()
{ {
return columns.size()*(2+window.second-window.first)+targets.size(); return columns.size()*(4+window.second-window.first)+targets.size();
} }
void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
...@@ -95,6 +95,8 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context ...@@ -95,6 +95,8 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
std::map<long,long> configIndex2ContextIndex; std::map<long,long> configIndex2ContextIndex;
contextIndexes.emplace_back(-1); contextIndexes.emplace_back(-1);
contextIndexes.emplace_back(-2);
contextIndexes.emplace_back(-3);
for (long i = window.first; i <= window.second; i++) for (long i = window.first; i <= window.second; i++)
{ {
...@@ -104,7 +106,7 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context ...@@ -104,7 +106,7 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
configIndex2ContextIndex[contextIndexes.back()] = contextIndexes.size()-1; configIndex2ContextIndex[contextIndexes.back()] = contextIndexes.size()-1;
} }
else else
contextIndexes.emplace_back(-1); contextIndexes.emplace_back(-3);
} }
for (auto & target : targets) for (auto & target : targets)
...@@ -144,6 +146,11 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context ...@@ -144,6 +146,11 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
for (auto & contextElement : context) for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col)); contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col));
} }
else if (index == -3)
{
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::oobValueStr, col));
}
else else
{ {
int dictIndex; int dictIndex;
...@@ -177,7 +184,17 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context ...@@ -177,7 +184,17 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context
else else
{ {
for (auto & contextElement : context) for (auto & contextElement : context)
{
// -1 == doesn't exist (s.0 when no stack)
if (index == -1)
contextElement.push_back(0); contextElement.push_back(0);
// -2 == nochild
else if (index == -2)
contextElement.push_back(1);
// other == out of context bounds
else
contextElement.push_back(2);
}
} }
} }
} }
...@@ -200,7 +217,7 @@ torch::Tensor batchedIndexSelect(torch::Tensor input, int dim, torch::Tensor ind ...@@ -200,7 +217,7 @@ torch::Tensor batchedIndexSelect(torch::Tensor input, int dim, torch::Tensor ind
torch::Tensor ContextualModuleImpl::forward(torch::Tensor input) torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
{ {
auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()-targets.size())).view({input.size(0), (2+window.second-window.first), -1}); auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()-targets.size())).view({input.size(0), (4+window.second-window.first), -1});
auto focusedIndexes = input.narrow(1, firstInputIndex+getInputSize()-targets.size(), targets.size()); auto focusedIndexes = input.narrow(1, firstInputIndex+getInputSize()-targets.size(), targets.size());
auto out = myModule->forward(context); auto out = myModule->forward(context);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment