Skip to content
Snippets Groups Projects
Commit b6fda44a authored by Franck Dary's avatar Franck Dary
Browse files

Forced EOS at the end of decode

parent af8ec4a7
No related branches found
No related tags found
No related merge requests found
...@@ -40,6 +40,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug) ...@@ -40,6 +40,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
config.addToHistory(transition->getName()); config.addToHistory(transition->getName());
auto movement = machine.getStrategy().getMovement(config, transition->getName()); auto movement = machine.getStrategy().getMovement(config, transition->getName());
if (debug)
fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
if (movement == Strategy::endMovement) if (movement == Strategy::endMovement)
break; break;
...@@ -48,6 +50,16 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug) ...@@ -48,6 +50,16 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
util::myThrow("Cannot move word index !"); util::myThrow("Cannot move word index !");
} }
} catch(std::exception & e) {util::myThrow(e.what());} } catch(std::exception & e) {util::myThrow(e.what());}
// Force EOS when needed
if (machine.getTransitionSet().getTransition("EOS") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1)
{
Action shift = Action::pushWordIndexOnStack();
shift.apply(config, shift);
machine.getTransitionSet().getTransition("EOS")->apply(config);
if (debug)
fmt::print(stderr, "Forcing EOS transition\n");
}
} }
float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
......
...@@ -20,6 +20,7 @@ class TransitionSet ...@@ -20,6 +20,7 @@ class TransitionSet
Transition * getBestAppliableTransition(const Config & c); Transition * getBestAppliableTransition(const Config & c);
std::size_t getTransitionIndex(const Transition * transition) const; std::size_t getTransitionIndex(const Transition * transition) const;
Transition * getTransition(std::size_t index); Transition * getTransition(std::size_t index);
Transition * getTransition(const std::string & name);
std::size_t size() const; std::size_t size() const;
}; };
......
...@@ -203,7 +203,7 @@ Action Action::setRoot() ...@@ -203,7 +203,7 @@ Action Action::setRoot()
{ {
int rootIndex = -1; int rootIndex = -1;
for (int i = config.getWordIndex()-1; true; --i) for (int i = config.getStack(0); true; --i)
{ {
if (!config.has(0, i, 0)) if (!config.has(0, i, 0))
{ {
...@@ -224,7 +224,7 @@ Action Action::setRoot() ...@@ -224,7 +224,7 @@ Action Action::setRoot()
} }
} }
for (int i = config.getWordIndex()-1; true; --i) for (int i = config.getStack(0); true; --i)
{ {
if (!config.has(0, i, 0)) if (!config.has(0, i, 0))
{ {
...@@ -276,7 +276,7 @@ Action Action::updateIds() ...@@ -276,7 +276,7 @@ Action Action::updateIds()
auto apply = [](Config & config, Action & a) auto apply = [](Config & config, Action & a)
{ {
int firstIndexOfSentence = -1; int firstIndexOfSentence = -1;
for (int i = config.getWordIndex()-1; true; --i) for (int i = config.getStack(0); true; --i)
{ {
if (!config.has(0, i, 0)) if (!config.has(0, i, 0))
{ {
...@@ -296,7 +296,7 @@ Action Action::updateIds() ...@@ -296,7 +296,7 @@ Action Action::updateIds()
if (firstIndexOfSentence < 0) if (firstIndexOfSentence < 0)
util::myThrow("could not find any token in current sentence"); util::myThrow("could not find any token in current sentence");
for (unsigned int i = firstIndexOfSentence, currentId = 1; i < config.getWordIndex(); ++i) for (unsigned int i = firstIndexOfSentence, currentId = 1; i <= config.getStack(0); ++i)
{ {
if (!config.isToken(i)) if (!config.isToken(i))
continue; continue;
......
...@@ -428,7 +428,7 @@ std::size_t Config::getStack(int relativeIndex) const ...@@ -428,7 +428,7 @@ std::size_t Config::getStack(int relativeIndex) const
bool Config::hasHistory(int relativeIndex) const bool Config::hasHistory(int relativeIndex) const
{ {
return relativeIndex > 0 && relativeIndex < (int)history.size(); return relativeIndex >= 0 && relativeIndex < (int)history.size();
} }
bool Config::hasStack(int relativeIndex) const bool Config::hasStack(int relativeIndex) const
...@@ -451,7 +451,7 @@ bool Config::stateIsDone() const ...@@ -451,7 +451,7 @@ bool Config::stateIsDone() const
if (!rawInput.empty()) if (!rawInput.empty())
return rawInputOnlySeparatorsLeft(); return rawInputOnlySeparatorsLeft();
return !has(0, wordIndex+1, 0); return !has(0, wordIndex+1, 0) and !hasStack(0);
} }
std::vector<long> Config::extractContext(int leftBorder, int rightBorder, Dict & dict) const std::vector<long> Config::extractContext(int leftBorder, int rightBorder, Dict & dict) const
......
...@@ -82,7 +82,7 @@ std::pair<std::string, int> Strategy::getMovementSequential(const Config & c, co ...@@ -82,7 +82,7 @@ std::pair<std::string, int> Strategy::getMovementSequential(const Config & c, co
util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition)); util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition));
if (!c.stateIsDone()) if (!c.stateIsDone())
return {c.getState(), movement}; return {c.getState(), c.canMoveWordIndex(movement) ? movement : 0};
if (!isDone[target]) if (!isDone[target])
return {target, -c.getWordIndex()}; return {target, -c.getWordIndex()};
......
...@@ -85,3 +85,13 @@ Transition * TransitionSet::getTransition(std::size_t index) ...@@ -85,3 +85,13 @@ Transition * TransitionSet::getTransition(std::size_t index)
return &transitions[index]; return &transitions[index];
} }
Transition * TransitionSet::getTransition(const std::string & name)
{
for (auto & transition : transitions)
if (transition.getName() == name)
return &transition;
return nullptr;
}
...@@ -38,6 +38,8 @@ void Trainer::createDataset(SubConfig & config, bool debug) ...@@ -38,6 +38,8 @@ void Trainer::createDataset(SubConfig & config, bool debug)
config.addToHistory(transition->getName()); config.addToHistory(transition->getName());
auto movement = machine.getStrategy().getMovement(config, transition->getName()); auto movement = machine.getStrategy().getMovement(config, transition->getName());
if (debug)
fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
if (movement == Strategy::endMovement) if (movement == Strategy::endMovement)
break; break;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment