Commit 8dfbc696 authored by Franck Dary's avatar Franck Dary
Browse files

Improved forcing EOS transition, usefull for lineByLine mode

parent a2a0af1c
......@@ -38,9 +38,10 @@ std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float
baseConfig = beam[0].config;
if (machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1)
if (baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1)
{
machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0")->apply(baseConfig);
auto eosTransition = Transition("EOS b.0");
eosTransition.apply(baseConfig);
if (debug)
{
fmt::print(stderr, "Forcing EOS transition\n");
......
......@@ -631,8 +631,12 @@ Action Action::setRoot(int bufferIndex)
{
int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex);
int rootIndex = -1;
int searchStartIndex = lineIndex;
if (searchStartIndex > 0 and config.getAsFeature(Config::idColName, lineIndex) != "1")
searchStartIndex--;
int firstSentIndex = lineIndex;
for (int i = lineIndex; true; --i)
for (int i = searchStartIndex; true; --i)
{
if (!config.has(0, i, 0))
{
......@@ -646,6 +650,14 @@ Action Action::setRoot(int bufferIndex)
if (config.getAsFeature(Config::EOSColName, i) == Config::EOSSymbol1)
break;
firstSentIndex = i;
}
for (int i = lineIndex; i >= firstSentIndex; --i)
{
if (!config.isTokenPredicted(i))
continue;
if (std::string(config.getAsFeature(Config::headColName, i)).empty())
{
rootIndex = i;
......@@ -653,20 +665,11 @@ Action Action::setRoot(int bufferIndex)
}
}
for (int i = lineIndex; true; --i)
for (int i = lineIndex; i >= firstSentIndex; --i)
{
if (!config.has(0, i, 0))
{
if (i < 0)
break;
util::myThrow("The current sentence is too long to be completly held by the data strucure. Consider increasing SubConfig::SpanSize");
}
if (!config.isTokenPredicted(i))
continue;
if (config.getAsFeature(Config::EOSColName, i) == Config::EOSSymbol1)
break;
if (std::string(config.getAsFeature(Config::headColName, i)).empty())
{
if (i == rootIndex)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment