Skip to content
Snippets Groups Projects
Commit 23b452ce authored by Franck Dary's avatar Franck Dary
Browse files

New Strategy class

parent d377af89
No related branches found
No related tags found
No related merge requests found
...@@ -100,10 +100,9 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool ...@@ -100,10 +100,9 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
} catch(std::exception & e) {util::myThrow(e.what());} } catch(std::exception & e) {util::myThrow(e.what());}
// Force EOS when needed // Force EOS when needed
if (machine.getTransitionSet().getTransition("EOS") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1) if (machine.getTransitionSet().getTransition("EOS b.0") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1)
{ {
machine.getTransitionSet().getTransition("SHIFT")->apply(config); machine.getTransitionSet().getTransition("EOS b.0")->apply(config);
machine.getTransitionSet().getTransition("EOS")->apply(config);
if (debug) if (debug)
{ {
fmt::print(stderr, "Forcing EOS transition\n"); fmt::print(stderr, "Forcing EOS transition\n");
......
...@@ -7,32 +7,47 @@ class Strategy ...@@ -7,32 +7,47 @@ class Strategy
{ {
public : public :
static inline std::pair<std::string, int> endMovement{"", 0}; using Movement = std::pair<std::string, int>;
static inline Movement endMovement{"", 0};
private : private :
enum Type class Block
{ {
Incremental, private :
Sequential
enum EndCondition
{
CannotMove
}; };
Type type; std::vector<EndCondition> endConditions;
std::map<std::pair<std::string, std::string>, std::pair<std::string, int>> edges; std::vector<std::tuple<std::string,std::string,std::string,int>> movements;
std::map<std::string, bool> isDone;
std::vector<std::string> defaultCycle;
std::vector<std::string> originalDefaultCycle;
std::string initialState{"UNDEFINED"};
private : private :
std::pair<std::string, int> getMovementSequential(const Config & c, const std::string & transition); static EndCondition str2condition(const std::string & s);
std::pair<std::string, int> getMovementIncremental(const Config & c, const std::string & transition);
public :
Block(std::vector<std::string> endConditionsStr);
void addMovement(std::string definition);
const std::string getInitialState() const;
bool empty();
Movement getMovement(const Config & c, const std::string & transition);
bool isFinished(const Config & c, const Movement & movement);
};
private :
std::string initialState{"UNDEFINED"};
std::vector<Block> blocks;
std::size_t currentBlock{0};
public : public :
Strategy(std::vector<std::string> lines); Strategy(std::vector<std::string> definition);
std::pair<std::string, int> getMovement(const Config & c, const std::string & transition); Movement getMovement(const Config & c, const std::string & transition);
const std::string getInitialState() const; const std::string getInitialState() const;
void reset(); void reset();
}; };
......
...@@ -104,9 +104,21 @@ void ReadingMachine::readFromFile(std::filesystem::path path) ...@@ -104,9 +104,21 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
})) }))
util::myThrow("No predictions specified"); util::myThrow("No predictions specified");
auto restOfFile = std::vector<std::string>(lines.begin()+curLine, lines.end()); if (!util::doIfNameMatch(std::regex("Strategy"), lines[curLine++], [this,&lines,&curLine](auto sm)
{
std::vector<std::string> strategyDefinition;
if (lines[curLine] != "{")
util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine]));
strategy.reset(new Strategy(restOfFile)); for (curLine++; curLine < lines.size(); curLine++)
{
if (lines[curLine] == "}")
break;
strategyDefinition.emplace_back(lines[curLine]);
}
strategy.reset(new Strategy(strategyDefinition));
}))
util::myThrow("No Strategy specified");
} catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));} } catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));}
} }
......
#include "Strategy.hpp" #include "Strategy.hpp"
Strategy::Strategy(std::vector<std::string> lines) Strategy::Strategy(std::vector<std::string> definition)
{ {
if (!util::doIfNameMatch(std::regex("Strategy : ((incremental)|(sequential))"), lines[0], [this](auto sm) std::regex blockRegex("(?:(?:\\s|\\t)*)Block(?:(?:\\s|\\t)*):(?:(?:\\s|\\t)*)End\\{(.*)\\}(?:(?:\\s|\\t)*)");
{type = sm[1] == "sequential" ? Type::Sequential : Type::Incremental;}))
util::myThrow(fmt::format("Invalid strategy identifier '{}'", lines[0]));
for (unsigned int i = 1; i < lines.size(); i++) for (auto & line : definition)
if (!util::doIfNameMatch(blockRegex, line, [this](auto sm)
{ {
std::replace(lines[i].begin(), lines[i].end(), '\t', ' '); blocks.emplace_back(util::split(sm.str(1), ' '));
auto splited = util::split(lines[i], ' '); }))
std::pair<std::string, std::string> key;
std::string value;
int movement;
if (splited.size() == 3)
{ {
key = std::pair<std::string,std::string>(splited[0], ""); if (blocks.empty())
value = splited[1]; util::myThrow(fmt::format("invalid line '{}', expeced 'Block : End{}'",line,"{...}"));
movement = std::stoi(std::string(splited[2])); blocks.back().addMovement(line);
if (defaultCycle.empty())
initialState = splited[0];
defaultCycle.emplace_back(value);
} }
else if (splited.size() == 4)
{ if (blocks.empty())
key = std::pair<std::string,std::string>(splited[0], splited[2]); util::myThrow("empty strategy");
value = splited[1]; for (auto & block : blocks)
movement = std::stoi(std::string(splited[3])); if (block.empty())
util::myThrow("there is an empty block");
} }
else
util::myThrow(fmt::format("Invalid strategy line '{}'", lines[i]));
if (edges.count(key)) Strategy::Movement Strategy::getMovement(const Config & c, const std::string & transition)
util::myThrow(fmt::format("Edge {} {} defined twice", key.first, key.second)); {
edges[key] = std::make_pair(value, movement); auto movement = blocks[currentBlock].getMovement(c, transition);
isDone[key.first] = false;
if (blocks[currentBlock].isFinished(c, movement))
{
currentBlock++;
if (currentBlock >= blocks.size())
return endMovement;
movement.first = blocks[currentBlock].getInitialState();
movement.second = -c.getWordIndex();
} }
if (edges.empty()) return movement;
util::myThrow("Strategy is empty");
if (type == Type::Sequential)
defaultCycle.pop_back();
std::reverse(defaultCycle.begin(), defaultCycle.end());
originalDefaultCycle = defaultCycle;
} }
std::pair<std::string, int> Strategy::getMovement(const Config & c, const std::string & transition) Strategy::Movement Strategy::Block::getMovement(const Config & c, const std::string & transition)
{ {
std::string transitionPrefix(util::split(transition, ' ')[0]); std::string transitionPrefix(util::split(transition, ' ')[0]);
auto currentState = c.getState();
if (c.stateIsDone()) for (auto & movement : movements)
isDone[c.getState()] = true;
if (type == Type::Sequential)
{ {
while (defaultCycle.size() && isDone[defaultCycle.back()]) auto fromState = std::get<0>(movement);
defaultCycle.pop_back(); auto toState = std::get<1>(movement);
auto trans = std::get<2>(movement);
auto mov = std::get<3>(movement);
return getMovementSequential(c, transitionPrefix); if (fromState == currentState and (trans == transitionPrefix or trans == "*"))
return std::make_pair(toState, mov);
} }
for (unsigned int i = 0; i < defaultCycle.size(); i++) util::myThrow(fmt::format("no movement found for state '{}' and transitionPrefix '{}'", currentState, transitionPrefix));
if (isDone[defaultCycle[i]]) return endMovement;
{
while (defaultCycle.size() != i)
defaultCycle.pop_back();
break;
}
return getMovementIncremental(c, transitionPrefix);
} }
std::pair<std::string, int> Strategy::getMovementSequential(const Config & c, const std::string & transition) const std::string Strategy::getInitialState() const
{
auto foundSpecific = edges.find(std::make_pair(c.getState(), transition));
auto foundGeneric = edges.find(std::make_pair(c.getState(), ""));
std::string target;
int movement = -1;
if (foundSpecific != edges.end())
{ {
target = foundSpecific->second.first; return blocks.at(0).getInitialState();
movement = foundSpecific->second.second;
} }
else if (foundGeneric != edges.end())
void Strategy::reset()
{ {
target = foundGeneric->second.first; currentBlock = 0;
movement = foundGeneric->second.second;
} }
if (target.empty()) Strategy::Block::Block(std::vector<std::string> endConditionsStr)
util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition)); {
for (auto & cond : endConditionsStr)
if (!c.stateIsDone()) endConditions.emplace_back(str2condition(cond));
return {c.getState(), c.canMoveWordIndex(movement) ? movement : 0};
if (!isDone[target])
return {target, -c.getWordIndex()};
return endMovement;
} }
std::pair<std::string, int> Strategy::getMovementIncremental(const Config & c, const std::string & transition) void Strategy::Block::addMovement(std::string definition)
{ {
auto foundSpecific = edges.find(std::make_pair(c.getState(), transition)); std::regex regex("(?:(?:\\s|\\t)*)(\\S+)(?:(?:\\s|\\t)+)(\\S+)(?:(?:\\s|\\t)+)(\\S+)(?:(?:\\s|\\t)+)(\\S+)(?:(?:\\s|\\t)*)");
auto foundGeneric = edges.find(std::make_pair(c.getState(), "")); auto errorMessage = fmt::format("invalid definition '{}' expected fromState toState transitionNamePrefix wordIndexMovement", definition);
std::string target; if (!util::doIfNameMatch(regex, definition, [this, &errorMessage](auto sm)
int movement = -1;
if (foundSpecific != edges.end())
{ {
target = foundSpecific->second.first; try
movement = foundSpecific->second.second;
}
else if (foundGeneric != edges.end())
{ {
target = foundGeneric->second.first; movements.emplace_back(std::make_tuple(sm.str(1), sm.str(2), sm.str(3), std::stoi(sm.str(4))));
movement = foundGeneric->second.second; } catch (std::exception & e) {util::myThrow(fmt::format("caught '{}' in {}", e.what(), errorMessage));}
}))
util::myThrow(errorMessage);
} }
if (target.empty()) Strategy::Block::EndCondition Strategy::Block::str2condition(const std::string & s)
util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition)); {
if (s == "cannotMove")
if (c.hasStack(0) and c.getStack(0) == c.getWordIndex() and not c.canMoveWordIndex(movement)) return EndCondition::CannotMove;
target = c.getState(); else
util::myThrow(fmt::format("invalid EndCondition '{}'", s));
if (!isDone[target])
return {target, c.canMoveWordIndex(movement) ? movement : 0};
if (defaultCycle.empty()) return EndCondition::CannotMove;
return endMovement; }
return {defaultCycle.back(), movement}; const std::string Strategy::Block::getInitialState() const
{
return std::get<0>(movements.at(0));
} }
const std::string Strategy::getInitialState() const bool Strategy::Block::empty()
{ {
return initialState; return movements.empty();
} }
void Strategy::reset() bool Strategy::Block::isFinished(const Config & c, const Movement & movement)
{
for (auto & condition : endConditions)
if (condition == EndCondition::CannotMove)
{ {
for (auto & it : isDone) if (c.canMoveWordIndex(movement.second))
it.second = false; return false;
defaultCycle = originalDefaultCycle; }
return true;
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment