Commit 61e77a5a authored by Franck Dary's avatar Franck Dary
Browse files

Changed the encoding of features in certain modules

parent 5e1f4062
......@@ -148,10 +148,19 @@ void ContextModuleImpl::addToContext(torch::Tensor & context, const Config & con
if (col == Config::idColName)
{
std::string value;
if (config.isMultiwordPredicted(index))
if (config.getAsFeature(Config::idColName, index).empty())
value = "empty";
else if (config.isMultiwordPredicted(index))
value = "multiword";
else if (config.getAsFeature(Config::isMultiColName, index) == Config::EOSSymbol1)
value = "part";
else if (config.isTokenPredicted(index))
value = "token";
else
{
config.printForDebug(stderr);
util::myThrow(fmt::format("{} col at index {} not token nor multiword", Config::idColName, index));
}
dictIndex = dict.getIndexOrInsert(value, col);
}
else
......
......@@ -156,11 +156,19 @@ void ContextualModuleImpl::addToContext(torch::Tensor & context, const Config &
if (col == Config::idColName)
{
std::string value;
if (config.isMultiwordPredicted(index))
if (config.getAsFeature(Config::idColName, index).empty())
value = "empty";
else if (config.isMultiwordPredicted(index))
value = "multiword";
else if (config.getAsFeature(Config::isMultiColName, index) == Config::EOSSymbol1)
value = "part";
else if (config.isTokenPredicted(index))
value = "token";
dictIndex = dict.getIndexOrInsert(value, col);
else
{
config.printForDebug(stderr);
util::myThrow(fmt::format("{} col at index {} not token nor multiword", Config::idColName, index));
}
}
else
{
......
......@@ -94,16 +94,16 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config
if (config.hasStack(index))
focusedIndexes.emplace_back(config.getStack(index));
else
focusedIndexes.emplace_back(-1);
focusedIndexes.emplace_back(-2);
int insertIndex = 0;
for (auto index : focusedIndexes)
{
if (index == -1)
if (index == -1 or index == -2)
{
for (int i = 0; i < maxNbElements; i++)
{
context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, column);
context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(index == -1 ? Dict::oobValueStr : Dict::nullValueStr, column);
insertIndex++;
}
continue;
......@@ -113,13 +113,11 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config
if (column == "FORM")
{
auto asUtf8 = util::splitAsUtf8(func(std::string(config.getAsFeature(column, index))));
//TODO don't use nullValueStr here
for (int i = 0; i < maxNbElements; i++)
if (i < (int)asUtf8.size())
elements.emplace_back(fmt::format("{}", asUtf8[i]));
else
elements.emplace_back(Dict::nullValueStr);
elements.emplace_back("<padding>");
}
else if (column == "FEATS")
{
......@@ -129,16 +127,18 @@ void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config
if (i < (int)splited.size())
elements.emplace_back(splited[i]);
else
elements.emplace_back(Dict::nullValueStr);
elements.emplace_back("<padding>");
}
else if (column == "ID")
else if (column == Config::idColName)
{
if (config.isTokenPredicted(index))
elements.emplace_back("TOKEN");
if (config.getAsFeature(Config::idColName, index).empty())
elements.emplace_back("empty");
else if (config.isMultiwordPredicted(index))
elements.emplace_back("MULTIWORD");
else if (config.isEmptyNodePredicted(index))
elements.emplace_back("EMPTYNODE");
elements.emplace_back("multiword");
else if (config.getAsFeature(Config::isMultiColName, index) == Config::EOSSymbol1)
elements.emplace_back("part");
else if (config.isTokenPredicted(index))
elements.emplace_back("token");
}
else if (column == "EOS")
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment