Skip to content
Snippets Groups Projects
Commit 23b452ce authored by Franck Dary's avatar Franck Dary
Browse files

New Strategy class

parent d377af89
No related branches found
No related tags found
No related merge requests found
......@@ -100,10 +100,9 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
} catch(std::exception & e) {util::myThrow(e.what());}
// Force EOS when needed
if (machine.getTransitionSet().getTransition("EOS") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1)
if (machine.getTransitionSet().getTransition("EOS b.0") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1)
{
machine.getTransitionSet().getTransition("SHIFT")->apply(config);
machine.getTransitionSet().getTransition("EOS")->apply(config);
machine.getTransitionSet().getTransition("EOS b.0")->apply(config);
if (debug)
{
fmt::print(stderr, "Forcing EOS transition\n");
......
......@@ -7,32 +7,47 @@ class Strategy
{
public :
static inline std::pair<std::string, int> endMovement{"", 0};
using Movement = std::pair<std::string, int>;
static inline Movement endMovement{"", 0};
private :
enum Type
class Block
{
Incremental,
Sequential
private :
enum EndCondition
{
CannotMove
};
Type type;
std::map<std::pair<std::string, std::string>, std::pair<std::string, int>> edges;
std::map<std::string, bool> isDone;
std::vector<std::string> defaultCycle;
std::vector<std::string> originalDefaultCycle;
std::string initialState{"UNDEFINED"};
std::vector<EndCondition> endConditions;
std::vector<std::tuple<std::string,std::string,std::string,int>> movements;
private :
std::pair<std::string, int> getMovementSequential(const Config & c, const std::string & transition);
std::pair<std::string, int> getMovementIncremental(const Config & c, const std::string & transition);
static EndCondition str2condition(const std::string & s);
public :
Block(std::vector<std::string> endConditionsStr);
void addMovement(std::string definition);
const std::string getInitialState() const;
bool empty();
Movement getMovement(const Config & c, const std::string & transition);
bool isFinished(const Config & c, const Movement & movement);
};
private :
std::string initialState{"UNDEFINED"};
std::vector<Block> blocks;
std::size_t currentBlock{0};
public :
Strategy(std::vector<std::string> lines);
std::pair<std::string, int> getMovement(const Config & c, const std::string & transition);
Strategy(std::vector<std::string> definition);
Movement getMovement(const Config & c, const std::string & transition);
const std::string getInitialState() const;
void reset();
};
......
......@@ -104,9 +104,21 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
}))
util::myThrow("No predictions specified");
auto restOfFile = std::vector<std::string>(lines.begin()+curLine, lines.end());
if (!util::doIfNameMatch(std::regex("Strategy"), lines[curLine++], [this,&lines,&curLine](auto sm)
{
std::vector<std::string> strategyDefinition;
if (lines[curLine] != "{")
util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine]));
strategy.reset(new Strategy(restOfFile));
for (curLine++; curLine < lines.size(); curLine++)
{
if (lines[curLine] == "}")
break;
strategyDefinition.emplace_back(lines[curLine]);
}
strategy.reset(new Strategy(strategyDefinition));
}))
util::myThrow("No Strategy specified");
} catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));}
}
......
#include "Strategy.hpp"
Strategy::Strategy(std::vector<std::string> lines)
Strategy::Strategy(std::vector<std::string> definition)
{
if (!util::doIfNameMatch(std::regex("Strategy : ((incremental)|(sequential))"), lines[0], [this](auto sm)
{type = sm[1] == "sequential" ? Type::Sequential : Type::Incremental;}))
util::myThrow(fmt::format("Invalid strategy identifier '{}'", lines[0]));
std::regex blockRegex("(?:(?:\\s|\\t)*)Block(?:(?:\\s|\\t)*):(?:(?:\\s|\\t)*)End\\{(.*)\\}(?:(?:\\s|\\t)*)");
for (unsigned int i = 1; i < lines.size(); i++)
for (auto & line : definition)
if (!util::doIfNameMatch(blockRegex, line, [this](auto sm)
{
std::replace(lines[i].begin(), lines[i].end(), '\t', ' ');
auto splited = util::split(lines[i], ' ');
std::pair<std::string, std::string> key;
std::string value;
int movement;
if (splited.size() == 3)
blocks.emplace_back(util::split(sm.str(1), ' '));
}))
{
key = std::pair<std::string,std::string>(splited[0], "");
value = splited[1];
movement = std::stoi(std::string(splited[2]));
if (defaultCycle.empty())
initialState = splited[0];
defaultCycle.emplace_back(value);
if (blocks.empty())
util::myThrow(fmt::format("invalid line '{}', expeced 'Block : End{}'",line,"{...}"));
blocks.back().addMovement(line);
}
else if (splited.size() == 4)
{
key = std::pair<std::string,std::string>(splited[0], splited[2]);
value = splited[1];
movement = std::stoi(std::string(splited[3]));
if (blocks.empty())
util::myThrow("empty strategy");
for (auto & block : blocks)
if (block.empty())
util::myThrow("there is an empty block");
}
else
util::myThrow(fmt::format("Invalid strategy line '{}'", lines[i]));
if (edges.count(key))
util::myThrow(fmt::format("Edge {} {} defined twice", key.first, key.second));
edges[key] = std::make_pair(value, movement);
isDone[key.first] = false;
Strategy::Movement Strategy::getMovement(const Config & c, const std::string & transition)
{
auto movement = blocks[currentBlock].getMovement(c, transition);
if (blocks[currentBlock].isFinished(c, movement))
{
currentBlock++;
if (currentBlock >= blocks.size())
return endMovement;
movement.first = blocks[currentBlock].getInitialState();
movement.second = -c.getWordIndex();
}
if (edges.empty())
util::myThrow("Strategy is empty");
if (type == Type::Sequential)
defaultCycle.pop_back();
std::reverse(defaultCycle.begin(), defaultCycle.end());
originalDefaultCycle = defaultCycle;
return movement;
}
std::pair<std::string, int> Strategy::getMovement(const Config & c, const std::string & transition)
Strategy::Movement Strategy::Block::getMovement(const Config & c, const std::string & transition)
{
std::string transitionPrefix(util::split(transition, ' ')[0]);
auto currentState = c.getState();
if (c.stateIsDone())
isDone[c.getState()] = true;
if (type == Type::Sequential)
for (auto & movement : movements)
{
while (defaultCycle.size() && isDone[defaultCycle.back()])
defaultCycle.pop_back();
auto fromState = std::get<0>(movement);
auto toState = std::get<1>(movement);
auto trans = std::get<2>(movement);
auto mov = std::get<3>(movement);
return getMovementSequential(c, transitionPrefix);
if (fromState == currentState and (trans == transitionPrefix or trans == "*"))
return std::make_pair(toState, mov);
}
for (unsigned int i = 0; i < defaultCycle.size(); i++)
if (isDone[defaultCycle[i]])
{
while (defaultCycle.size() != i)
defaultCycle.pop_back();
break;
}
return getMovementIncremental(c, transitionPrefix);
util::myThrow(fmt::format("no movement found for state '{}' and transitionPrefix '{}'", currentState, transitionPrefix));
return endMovement;
}
std::pair<std::string, int> Strategy::getMovementSequential(const Config & c, const std::string & transition)
{
auto foundSpecific = edges.find(std::make_pair(c.getState(), transition));
auto foundGeneric = edges.find(std::make_pair(c.getState(), ""));
std::string target;
int movement = -1;
if (foundSpecific != edges.end())
const std::string Strategy::getInitialState() const
{
target = foundSpecific->second.first;
movement = foundSpecific->second.second;
return blocks.at(0).getInitialState();
}
else if (foundGeneric != edges.end())
void Strategy::reset()
{
target = foundGeneric->second.first;
movement = foundGeneric->second.second;
currentBlock = 0;
}
if (target.empty())
util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition));
if (!c.stateIsDone())
return {c.getState(), c.canMoveWordIndex(movement) ? movement : 0};
if (!isDone[target])
return {target, -c.getWordIndex()};
return endMovement;
Strategy::Block::Block(std::vector<std::string> endConditionsStr)
{
for (auto & cond : endConditionsStr)
endConditions.emplace_back(str2condition(cond));
}
std::pair<std::string, int> Strategy::getMovementIncremental(const Config & c, const std::string & transition)
void Strategy::Block::addMovement(std::string definition)
{
auto foundSpecific = edges.find(std::make_pair(c.getState(), transition));
auto foundGeneric = edges.find(std::make_pair(c.getState(), ""));
std::regex regex("(?:(?:\\s|\\t)*)(\\S+)(?:(?:\\s|\\t)+)(\\S+)(?:(?:\\s|\\t)+)(\\S+)(?:(?:\\s|\\t)+)(\\S+)(?:(?:\\s|\\t)*)");
auto errorMessage = fmt::format("invalid definition '{}' expected fromState toState transitionNamePrefix wordIndexMovement", definition);
std::string target;
int movement = -1;
if (foundSpecific != edges.end())
if (!util::doIfNameMatch(regex, definition, [this, &errorMessage](auto sm)
{
target = foundSpecific->second.first;
movement = foundSpecific->second.second;
}
else if (foundGeneric != edges.end())
try
{
target = foundGeneric->second.first;
movement = foundGeneric->second.second;
movements.emplace_back(std::make_tuple(sm.str(1), sm.str(2), sm.str(3), std::stoi(sm.str(4))));
} catch (std::exception & e) {util::myThrow(fmt::format("caught '{}' in {}", e.what(), errorMessage));}
}))
util::myThrow(errorMessage);
}
if (target.empty())
util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition));
if (c.hasStack(0) and c.getStack(0) == c.getWordIndex() and not c.canMoveWordIndex(movement))
target = c.getState();
if (!isDone[target])
return {target, c.canMoveWordIndex(movement) ? movement : 0};
Strategy::Block::EndCondition Strategy::Block::str2condition(const std::string & s)
{
if (s == "cannotMove")
return EndCondition::CannotMove;
else
util::myThrow(fmt::format("invalid EndCondition '{}'", s));
if (defaultCycle.empty())
return endMovement;
return EndCondition::CannotMove;
}
return {defaultCycle.back(), movement};
const std::string Strategy::Block::getInitialState() const
{
return std::get<0>(movements.at(0));
}
const std::string Strategy::getInitialState() const
bool Strategy::Block::empty()
{
return initialState;
return movements.empty();
}
void Strategy::reset()
bool Strategy::Block::isFinished(const Config & c, const Movement & movement)
{
for (auto & condition : endConditions)
if (condition == EndCondition::CannotMove)
{
for (auto & it : isDone)
it.second = false;
defaultCycle = originalDefaultCycle;
if (c.canMoveWordIndex(movement.second))
return false;
}
return true;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment