Skip to content
Snippets Groups Projects
Commit 8dfbc696 authored by Franck Dary's avatar Franck Dary
Browse files

Improved forcing EOS transition, usefull for lineByLine mode

parent a2a0af1c
Branches
No related tags found
No related merge requests found
......@@ -38,9 +38,10 @@ std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float
baseConfig = beam[0].config;
if (machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1)
if (baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1)
{
machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0")->apply(baseConfig);
auto eosTransition = Transition("EOS b.0");
eosTransition.apply(baseConfig);
if (debug)
{
fmt::print(stderr, "Forcing EOS transition\n");
......
......@@ -631,8 +631,12 @@ Action Action::setRoot(int bufferIndex)
{
int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex);
int rootIndex = -1;
int searchStartIndex = lineIndex;
if (searchStartIndex > 0 and config.getAsFeature(Config::idColName, lineIndex) != "1")
searchStartIndex--;
int firstSentIndex = lineIndex;
for (int i = lineIndex; true; --i)
for (int i = searchStartIndex; true; --i)
{
if (!config.has(0, i, 0))
{
......@@ -646,6 +650,14 @@ Action Action::setRoot(int bufferIndex)
if (config.getAsFeature(Config::EOSColName, i) == Config::EOSSymbol1)
break;
firstSentIndex = i;
}
for (int i = lineIndex; i >= firstSentIndex; --i)
{
if (!config.isTokenPredicted(i))
continue;
if (std::string(config.getAsFeature(Config::headColName, i)).empty())
{
rootIndex = i;
......@@ -653,20 +665,11 @@ Action Action::setRoot(int bufferIndex)
}
}
for (int i = lineIndex; true; --i)
for (int i = lineIndex; i >= firstSentIndex; --i)
{
if (!config.has(0, i, 0))
{
if (i < 0)
break;
util::myThrow("The current sentence is too long to be completly held by the data strucure. Consider increasing SubConfig::SpanSize");
}
if (!config.isTokenPredicted(i))
continue;
if (config.getAsFeature(Config::EOSColName, i) == Config::EOSSymbol1)
break;
if (std::string(config.getAsFeature(Config::headColName, i)).empty())
{
if (i == rootIndex)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment