Skip to content
Snippets Groups Projects
Commit 23b452ce authored by Franck Dary's avatar Franck Dary
Browse files

New Strategy class

parent d377af89
No related branches found
No related tags found
No related merge requests found
...@@ -100,10 +100,9 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool ...@@ -100,10 +100,9 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
} catch(std::exception & e) {util::myThrow(e.what());} } catch(std::exception & e) {util::myThrow(e.what());}
// Force EOS when needed // Force EOS when needed
if (machine.getTransitionSet().getTransition("EOS") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1) if (machine.getTransitionSet().getTransition("EOS b.0") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1)
{ {
machine.getTransitionSet().getTransition("SHIFT")->apply(config); machine.getTransitionSet().getTransition("EOS b.0")->apply(config);
machine.getTransitionSet().getTransition("EOS")->apply(config);
if (debug) if (debug)
{ {
fmt::print(stderr, "Forcing EOS transition\n"); fmt::print(stderr, "Forcing EOS transition\n");
......
...@@ -7,32 +7,47 @@ class Strategy ...@@ -7,32 +7,47 @@ class Strategy
{ {
public : public :
static inline std::pair<std::string, int> endMovement{"", 0}; using Movement = std::pair<std::string, int>;
static inline Movement endMovement{"", 0};
private : private :
enum Type class Block
{ {
Incremental, private :
Sequential
};
Type type; enum EndCondition
std::map<std::pair<std::string, std::string>, std::pair<std::string, int>> edges; {
std::map<std::string, bool> isDone; CannotMove
std::vector<std::string> defaultCycle; };
std::vector<std::string> originalDefaultCycle;
std::string initialState{"UNDEFINED"}; std::vector<EndCondition> endConditions;
std::vector<std::tuple<std::string,std::string,std::string,int>> movements;
private :
static EndCondition str2condition(const std::string & s);
public :
Block(std::vector<std::string> endConditionsStr);
void addMovement(std::string definition);
const std::string getInitialState() const;
bool empty();
Movement getMovement(const Config & c, const std::string & transition);
bool isFinished(const Config & c, const Movement & movement);
};
private : private :
std::pair<std::string, int> getMovementSequential(const Config & c, const std::string & transition); std::string initialState{"UNDEFINED"};
std::pair<std::string, int> getMovementIncremental(const Config & c, const std::string & transition); std::vector<Block> blocks;
std::size_t currentBlock{0};
public : public :
Strategy(std::vector<std::string> lines); Strategy(std::vector<std::string> definition);
std::pair<std::string, int> getMovement(const Config & c, const std::string & transition); Movement getMovement(const Config & c, const std::string & transition);
const std::string getInitialState() const; const std::string getInitialState() const;
void reset(); void reset();
}; };
......
...@@ -104,9 +104,21 @@ void ReadingMachine::readFromFile(std::filesystem::path path) ...@@ -104,9 +104,21 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
})) }))
util::myThrow("No predictions specified"); util::myThrow("No predictions specified");
auto restOfFile = std::vector<std::string>(lines.begin()+curLine, lines.end()); if (!util::doIfNameMatch(std::regex("Strategy"), lines[curLine++], [this,&lines,&curLine](auto sm)
{
std::vector<std::string> strategyDefinition;
if (lines[curLine] != "{")
util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine]));
strategy.reset(new Strategy(restOfFile)); for (curLine++; curLine < lines.size(); curLine++)
{
if (lines[curLine] == "}")
break;
strategyDefinition.emplace_back(lines[curLine]);
}
strategy.reset(new Strategy(strategyDefinition));
}))
util::myThrow("No Strategy specified");
} catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));} } catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));}
} }
......
#include "Strategy.hpp" #include "Strategy.hpp"
Strategy::Strategy(std::vector<std::string> lines) Strategy::Strategy(std::vector<std::string> definition)
{ {
if (!util::doIfNameMatch(std::regex("Strategy : ((incremental)|(sequential))"), lines[0], [this](auto sm) std::regex blockRegex("(?:(?:\\s|\\t)*)Block(?:(?:\\s|\\t)*):(?:(?:\\s|\\t)*)End\\{(.*)\\}(?:(?:\\s|\\t)*)");
{type = sm[1] == "sequential" ? Type::Sequential : Type::Incremental;}))
util::myThrow(fmt::format("Invalid strategy identifier '{}'", lines[0]));
for (unsigned int i = 1; i < lines.size(); i++) for (auto & line : definition)
{ if (!util::doIfNameMatch(blockRegex, line, [this](auto sm)
std::replace(lines[i].begin(), lines[i].end(), '\t', ' '); {
auto splited = util::split(lines[i], ' '); blocks.emplace_back(util::split(sm.str(1), ' '));
std::pair<std::string, std::string> key; }))
std::string value;
int movement;
if (splited.size() == 3)
{
key = std::pair<std::string,std::string>(splited[0], "");
value = splited[1];
movement = std::stoi(std::string(splited[2]));
if (defaultCycle.empty())
initialState = splited[0];
defaultCycle.emplace_back(value);
}
else if (splited.size() == 4)
{ {
key = std::pair<std::string,std::string>(splited[0], splited[2]); if (blocks.empty())
value = splited[1]; util::myThrow(fmt::format("invalid line '{}', expeced 'Block : End{}'",line,"{...}"));
movement = std::stoi(std::string(splited[3])); blocks.back().addMovement(line);
} }
else
util::myThrow(fmt::format("Invalid strategy line '{}'", lines[i]));
if (edges.count(key))
util::myThrow(fmt::format("Edge {} {} defined twice", key.first, key.second));
edges[key] = std::make_pair(value, movement);
isDone[key.first] = false;
}
if (edges.empty()) if (blocks.empty())
util::myThrow("Strategy is empty"); util::myThrow("empty strategy");
if (type == Type::Sequential) for (auto & block : blocks)
defaultCycle.pop_back(); if (block.empty())
std::reverse(defaultCycle.begin(), defaultCycle.end()); util::myThrow("there is an empty block");
originalDefaultCycle = defaultCycle;
} }
std::pair<std::string, int> Strategy::getMovement(const Config & c, const std::string & transition) Strategy::Movement Strategy::getMovement(const Config & c, const std::string & transition)
{ {
std::string transitionPrefix(util::split(transition, ' ')[0]); auto movement = blocks[currentBlock].getMovement(c, transition);
if (c.stateIsDone())
isDone[c.getState()] = true;
if (type == Type::Sequential) if (blocks[currentBlock].isFinished(c, movement))
{ {
while (defaultCycle.size() && isDone[defaultCycle.back()]) currentBlock++;
defaultCycle.pop_back(); if (currentBlock >= blocks.size())
return endMovement;
return getMovementSequential(c, transitionPrefix); movement.first = blocks[currentBlock].getInitialState();
movement.second = -c.getWordIndex();
} }
for (unsigned int i = 0; i < defaultCycle.size(); i++) return movement;
if (isDone[defaultCycle[i]])
{
while (defaultCycle.size() != i)
defaultCycle.pop_back();
break;
}
return getMovementIncremental(c, transitionPrefix);
} }
std::pair<std::string, int> Strategy::getMovementSequential(const Config & c, const std::string & transition) Strategy::Movement Strategy::Block::getMovement(const Config & c, const std::string & transition)
{ {
auto foundSpecific = edges.find(std::make_pair(c.getState(), transition)); std::string transitionPrefix(util::split(transition, ' ')[0]);
auto foundGeneric = edges.find(std::make_pair(c.getState(), "")); auto currentState = c.getState();
std::string target;
int movement = -1;
if (foundSpecific != edges.end()) for (auto & movement : movements)
{
target = foundSpecific->second.first;
movement = foundSpecific->second.second;
}
else if (foundGeneric != edges.end())
{ {
target = foundGeneric->second.first; auto fromState = std::get<0>(movement);
movement = foundGeneric->second.second; auto toState = std::get<1>(movement);
} auto trans = std::get<2>(movement);
auto mov = std::get<3>(movement);
if (target.empty())
util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition));
if (!c.stateIsDone())
return {c.getState(), c.canMoveWordIndex(movement) ? movement : 0};
if (!isDone[target]) if (fromState == currentState and (trans == transitionPrefix or trans == "*"))
return {target, -c.getWordIndex()}; return std::make_pair(toState, mov);
}
util::myThrow(fmt::format("no movement found for state '{}' and transitionPrefix '{}'", currentState, transitionPrefix));
return endMovement; return endMovement;
} }
std::pair<std::string, int> Strategy::getMovementIncremental(const Config & c, const std::string & transition) const std::string Strategy::getInitialState() const
{ {
auto foundSpecific = edges.find(std::make_pair(c.getState(), transition)); return blocks.at(0).getInitialState();
auto foundGeneric = edges.find(std::make_pair(c.getState(), "")); }
std::string target; void Strategy::reset()
int movement = -1; {
currentBlock = 0;
}
if (foundSpecific != edges.end()) Strategy::Block::Block(std::vector<std::string> endConditionsStr)
{ {
target = foundSpecific->second.first; for (auto & cond : endConditionsStr)
movement = foundSpecific->second.second; endConditions.emplace_back(str2condition(cond));
} }
else if (foundGeneric != edges.end())
{
target = foundGeneric->second.first;
movement = foundGeneric->second.second;
}
if (target.empty()) void Strategy::Block::addMovement(std::string definition)
util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition)); {
std::regex regex("(?:(?:\\s|\\t)*)(\\S+)(?:(?:\\s|\\t)+)(\\S+)(?:(?:\\s|\\t)+)(\\S+)(?:(?:\\s|\\t)+)(\\S+)(?:(?:\\s|\\t)*)");
auto errorMessage = fmt::format("invalid definition '{}' expected fromState toState transitionNamePrefix wordIndexMovement", definition);
if (c.hasStack(0) and c.getStack(0) == c.getWordIndex() and not c.canMoveWordIndex(movement)) if (!util::doIfNameMatch(regex, definition, [this, &errorMessage](auto sm)
target = c.getState(); {
try
{
movements.emplace_back(std::make_tuple(sm.str(1), sm.str(2), sm.str(3), std::stoi(sm.str(4))));
} catch (std::exception & e) {util::myThrow(fmt::format("caught '{}' in {}", e.what(), errorMessage));}
}))
util::myThrow(errorMessage);
}
if (!isDone[target]) Strategy::Block::EndCondition Strategy::Block::str2condition(const std::string & s)
return {target, c.canMoveWordIndex(movement) ? movement : 0}; {
if (s == "cannotMove")
return EndCondition::CannotMove;
else
util::myThrow(fmt::format("invalid EndCondition '{}'", s));
if (defaultCycle.empty()) return EndCondition::CannotMove;
return endMovement; }
return {defaultCycle.back(), movement}; const std::string Strategy::Block::getInitialState() const
{
return std::get<0>(movements.at(0));
} }
const std::string Strategy::getInitialState() const bool Strategy::Block::empty()
{ {
return initialState; return movements.empty();
} }
void Strategy::reset() bool Strategy::Block::isFinished(const Config & c, const Movement & movement)
{ {
for (auto & it : isDone) for (auto & condition : endConditions)
it.second = false; if (condition == EndCondition::CannotMove)
defaultCycle = originalDefaultCycle; {
if (c.canMoveWordIndex(movement.second))
return false;
}
return true;
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment