Skip to content
Snippets Groups Projects
Commit b6fda44a authored by Franck Dary's avatar Franck Dary
Browse files

Forced EOS at the end of decode

parent af8ec4a7
No related branches found
No related tags found
No related merge requests found
......@@ -40,6 +40,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
config.addToHistory(transition->getName());
auto movement = machine.getStrategy().getMovement(config, transition->getName());
if (debug)
fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
if (movement == Strategy::endMovement)
break;
......@@ -48,6 +50,16 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
util::myThrow("Cannot move word index !");
}
} catch(std::exception & e) {util::myThrow(e.what());}
// Force EOS when needed
if (machine.getTransitionSet().getTransition("EOS") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1)
{
Action shift = Action::pushWordIndexOnStack();
shift.apply(config, shift);
machine.getTransitionSet().getTransition("EOS")->apply(config);
if (debug)
fmt::print(stderr, "Forcing EOS transition\n");
}
}
float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
......
......@@ -20,6 +20,7 @@ class TransitionSet
Transition * getBestAppliableTransition(const Config & c);
std::size_t getTransitionIndex(const Transition * transition) const;
Transition * getTransition(std::size_t index);
Transition * getTransition(const std::string & name);
std::size_t size() const;
};
......
......@@ -203,7 +203,7 @@ Action Action::setRoot()
{
int rootIndex = -1;
for (int i = config.getWordIndex()-1; true; --i)
for (int i = config.getStack(0); true; --i)
{
if (!config.has(0, i, 0))
{
......@@ -224,7 +224,7 @@ Action Action::setRoot()
}
}
for (int i = config.getWordIndex()-1; true; --i)
for (int i = config.getStack(0); true; --i)
{
if (!config.has(0, i, 0))
{
......@@ -276,7 +276,7 @@ Action Action::updateIds()
auto apply = [](Config & config, Action & a)
{
int firstIndexOfSentence = -1;
for (int i = config.getWordIndex()-1; true; --i)
for (int i = config.getStack(0); true; --i)
{
if (!config.has(0, i, 0))
{
......@@ -296,7 +296,7 @@ Action Action::updateIds()
if (firstIndexOfSentence < 0)
util::myThrow("could not find any token in current sentence");
for (unsigned int i = firstIndexOfSentence, currentId = 1; i < config.getWordIndex(); ++i)
for (unsigned int i = firstIndexOfSentence, currentId = 1; i <= config.getStack(0); ++i)
{
if (!config.isToken(i))
continue;
......
......@@ -428,7 +428,7 @@ std::size_t Config::getStack(int relativeIndex) const
bool Config::hasHistory(int relativeIndex) const
{
return relativeIndex > 0 && relativeIndex < (int)history.size();
return relativeIndex >= 0 && relativeIndex < (int)history.size();
}
bool Config::hasStack(int relativeIndex) const
......@@ -451,7 +451,7 @@ bool Config::stateIsDone() const
if (!rawInput.empty())
return rawInputOnlySeparatorsLeft();
return !has(0, wordIndex+1, 0);
return !has(0, wordIndex+1, 0) and !hasStack(0);
}
std::vector<long> Config::extractContext(int leftBorder, int rightBorder, Dict & dict) const
......
......@@ -82,7 +82,7 @@ std::pair<std::string, int> Strategy::getMovementSequential(const Config & c, co
util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition));
if (!c.stateIsDone())
return {c.getState(), movement};
return {c.getState(), c.canMoveWordIndex(movement) ? movement : 0};
if (!isDone[target])
return {target, -c.getWordIndex()};
......
......@@ -85,3 +85,13 @@ Transition * TransitionSet::getTransition(std::size_t index)
return &transitions[index];
}
Transition * TransitionSet::getTransition(const std::string & name)
{
for (auto & transition : transitions)
if (transition.getName() == name)
return &transition;
return nullptr;
}
......@@ -38,6 +38,8 @@ void Trainer::createDataset(SubConfig & config, bool debug)
config.addToHistory(transition->getName());
auto movement = machine.getStrategy().getMovement(config, transition->getName());
if (debug)
fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
if (movement == Strategy::endMovement)
break;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment