Skip to content
Snippets Groups Projects
Commit 23b452ce authored by Franck Dary's avatar Franck Dary
Browse files

New Strategy class

parent d377af89
No related branches found
No related tags found
No related merge requests found
......@@ -100,10 +100,9 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
} catch(std::exception & e) {util::myThrow(e.what());}
// Force EOS when needed
if (machine.getTransitionSet().getTransition("EOS") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1)
if (machine.getTransitionSet().getTransition("EOS b.0") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1)
{
machine.getTransitionSet().getTransition("SHIFT")->apply(config);
machine.getTransitionSet().getTransition("EOS")->apply(config);
machine.getTransitionSet().getTransition("EOS b.0")->apply(config);
if (debug)
{
fmt::print(stderr, "Forcing EOS transition\n");
......
......@@ -7,32 +7,47 @@ class Strategy
{
public :
static inline std::pair<std::string, int> endMovement{"", 0};
using Movement = std::pair<std::string, int>;
static inline Movement endMovement{"", 0};
private :
enum Type
class Block
{
Incremental,
Sequential
};
private :
Type type;
std::map<std::pair<std::string, std::string>, std::pair<std::string, int>> edges;
std::map<std::string, bool> isDone;
std::vector<std::string> defaultCycle;
std::vector<std::string> originalDefaultCycle;
std::string initialState{"UNDEFINED"};
enum EndCondition
{
CannotMove
};
std::vector<EndCondition> endConditions;
std::vector<std::tuple<std::string,std::string,std::string,int>> movements;
private :
static EndCondition str2condition(const std::string & s);
public :
Block(std::vector<std::string> endConditionsStr);
void addMovement(std::string definition);
const std::string getInitialState() const;
bool empty();
Movement getMovement(const Config & c, const std::string & transition);
bool isFinished(const Config & c, const Movement & movement);
};
private :
std::pair<std::string, int> getMovementSequential(const Config & c, const std::string & transition);
std::pair<std::string, int> getMovementIncremental(const Config & c, const std::string & transition);
std::string initialState{"UNDEFINED"};
std::vector<Block> blocks;
std::size_t currentBlock{0};
public :
Strategy(std::vector<std::string> lines);
std::pair<std::string, int> getMovement(const Config & c, const std::string & transition);
Strategy(std::vector<std::string> definition);
Movement getMovement(const Config & c, const std::string & transition);
const std::string getInitialState() const;
void reset();
};
......
......@@ -104,9 +104,21 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
}))
util::myThrow("No predictions specified");
auto restOfFile = std::vector<std::string>(lines.begin()+curLine, lines.end());
if (!util::doIfNameMatch(std::regex("Strategy"), lines[curLine++], [this,&lines,&curLine](auto sm)
{
std::vector<std::string> strategyDefinition;
if (lines[curLine] != "{")
util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine]));
strategy.reset(new Strategy(restOfFile));
for (curLine++; curLine < lines.size(); curLine++)
{
if (lines[curLine] == "}")
break;
strategyDefinition.emplace_back(lines[curLine]);
}
strategy.reset(new Strategy(strategyDefinition));
}))
util::myThrow("No Strategy specified");
} catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));}
}
......
#include "Strategy.hpp"
Strategy::Strategy(std::vector<std::string> lines)
Strategy::Strategy(std::vector<std::string> definition)
{
if (!util::doIfNameMatch(std::regex("Strategy : ((incremental)|(sequential))"), lines[0], [this](auto sm)
{type = sm[1] == "sequential" ? Type::Sequential : Type::Incremental;}))
util::myThrow(fmt::format("Invalid strategy identifier '{}'", lines[0]));
std::regex blockRegex("(?:(?:\\s|\\t)*)Block(?:(?:\\s|\\t)*):(?:(?:\\s|\\t)*)End\\{(.*)\\}(?:(?:\\s|\\t)*)");
for (unsigned int i = 1; i < lines.size(); i++)
{
std::replace(lines[i].begin(), lines[i].end(), '\t', ' ');
auto splited = util::split(lines[i], ' ');
std::pair<std::string, std::string> key;
std::string value;
int movement;
if (splited.size() == 3)
{
key = std::pair<std::string,std::string>(splited[0], "");
value = splited[1];
movement = std::stoi(std::string(splited[2]));
if (defaultCycle.empty())
initialState = splited[0];
defaultCycle.emplace_back(value);
}
else if (splited.size() == 4)
for (auto & line : definition)
if (!util::doIfNameMatch(blockRegex, line, [this](auto sm)
{
blocks.emplace_back(util::split(sm.str(1), ' '));
}))
{
key = std::pair<std::string,std::string>(splited[0], splited[2]);
value = splited[1];
movement = std::stoi(std::string(splited[3]));
if (blocks.empty())
util::myThrow(fmt::format("invalid line '{}', expeced 'Block : End{}'",line,"{...}"));
blocks.back().addMovement(line);
}
else
util::myThrow(fmt::format("Invalid strategy line '{}'", lines[i]));
if (edges.count(key))
util::myThrow(fmt::format("Edge {} {} defined twice", key.first, key.second));
edges[key] = std::make_pair(value, movement);
isDone[key.first] = false;
}
if (edges.empty())
util::myThrow("Strategy is empty");
if (type == Type::Sequential)
defaultCycle.pop_back();
std::reverse(defaultCycle.begin(), defaultCycle.end());
originalDefaultCycle = defaultCycle;
if (blocks.empty())
util::myThrow("empty strategy");
for (auto & block : blocks)
if (block.empty())
util::myThrow("there is an empty block");
}
std::pair<std::string, int> Strategy::getMovement(const Config & c, const std::string & transition)
Strategy::Movement Strategy::getMovement(const Config & c, const std::string & transition)
{
std::string transitionPrefix(util::split(transition, ' ')[0]);
if (c.stateIsDone())
isDone[c.getState()] = true;
auto movement = blocks[currentBlock].getMovement(c, transition);
if (type == Type::Sequential)
if (blocks[currentBlock].isFinished(c, movement))
{
while (defaultCycle.size() && isDone[defaultCycle.back()])
defaultCycle.pop_back();
return getMovementSequential(c, transitionPrefix);
currentBlock++;
if (currentBlock >= blocks.size())
return endMovement;
movement.first = blocks[currentBlock].getInitialState();
movement.second = -c.getWordIndex();
}
for (unsigned int i = 0; i < defaultCycle.size(); i++)
if (isDone[defaultCycle[i]])
{
while (defaultCycle.size() != i)
defaultCycle.pop_back();
break;
}
return getMovementIncremental(c, transitionPrefix);
return movement;
}
std::pair<std::string, int> Strategy::getMovementSequential(const Config & c, const std::string & transition)
Strategy::Movement Strategy::Block::getMovement(const Config & c, const std::string & transition)
{
auto foundSpecific = edges.find(std::make_pair(c.getState(), transition));
auto foundGeneric = edges.find(std::make_pair(c.getState(), ""));
std::string target;
int movement = -1;
std::string transitionPrefix(util::split(transition, ' ')[0]);
auto currentState = c.getState();
if (foundSpecific != edges.end())
{
target = foundSpecific->second.first;
movement = foundSpecific->second.second;
}
else if (foundGeneric != edges.end())
for (auto & movement : movements)
{
target = foundGeneric->second.first;
movement = foundGeneric->second.second;
}
if (target.empty())
util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition));
if (!c.stateIsDone())
return {c.getState(), c.canMoveWordIndex(movement) ? movement : 0};
auto fromState = std::get<0>(movement);
auto toState = std::get<1>(movement);
auto trans = std::get<2>(movement);
auto mov = std::get<3>(movement);
if (!isDone[target])
return {target, -c.getWordIndex()};
if (fromState == currentState and (trans == transitionPrefix or trans == "*"))
return std::make_pair(toState, mov);
}
util::myThrow(fmt::format("no movement found for state '{}' and transitionPrefix '{}'", currentState, transitionPrefix));
return endMovement;
}
std::pair<std::string, int> Strategy::getMovementIncremental(const Config & c, const std::string & transition)
const std::string Strategy::getInitialState() const
{
auto foundSpecific = edges.find(std::make_pair(c.getState(), transition));
auto foundGeneric = edges.find(std::make_pair(c.getState(), ""));
return blocks.at(0).getInitialState();
}
std::string target;
int movement = -1;
void Strategy::reset()
{
currentBlock = 0;
}
if (foundSpecific != edges.end())
{
target = foundSpecific->second.first;
movement = foundSpecific->second.second;
}
else if (foundGeneric != edges.end())
{
target = foundGeneric->second.first;
movement = foundGeneric->second.second;
}
Strategy::Block::Block(std::vector<std::string> endConditionsStr)
{
for (auto & cond : endConditionsStr)
endConditions.emplace_back(str2condition(cond));
}
if (target.empty())
util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition));
void Strategy::Block::addMovement(std::string definition)
{
std::regex regex("(?:(?:\\s|\\t)*)(\\S+)(?:(?:\\s|\\t)+)(\\S+)(?:(?:\\s|\\t)+)(\\S+)(?:(?:\\s|\\t)+)(\\S+)(?:(?:\\s|\\t)*)");
auto errorMessage = fmt::format("invalid definition '{}' expected fromState toState transitionNamePrefix wordIndexMovement", definition);
if (c.hasStack(0) and c.getStack(0) == c.getWordIndex() and not c.canMoveWordIndex(movement))
target = c.getState();
if (!util::doIfNameMatch(regex, definition, [this, &errorMessage](auto sm)
{
try
{
movements.emplace_back(std::make_tuple(sm.str(1), sm.str(2), sm.str(3), std::stoi(sm.str(4))));
} catch (std::exception & e) {util::myThrow(fmt::format("caught '{}' in {}", e.what(), errorMessage));}
}))
util::myThrow(errorMessage);
}
if (!isDone[target])
return {target, c.canMoveWordIndex(movement) ? movement : 0};
Strategy::Block::EndCondition Strategy::Block::str2condition(const std::string & s)
{
if (s == "cannotMove")
return EndCondition::CannotMove;
else
util::myThrow(fmt::format("invalid EndCondition '{}'", s));
if (defaultCycle.empty())
return endMovement;
return EndCondition::CannotMove;
}
return {defaultCycle.back(), movement};
const std::string Strategy::Block::getInitialState() const
{
return std::get<0>(movements.at(0));
}
const std::string Strategy::getInitialState() const
bool Strategy::Block::empty()
{
return initialState;
return movements.empty();
}
void Strategy::reset()
bool Strategy::Block::isFinished(const Config & c, const Movement & movement)
{
for (auto & it : isDone)
it.second = false;
defaultCycle = originalDefaultCycle;
for (auto & condition : endConditions)
if (condition == EndCondition::CannotMove)
{
if (c.canMoveWordIndex(movement.second))
return false;
}
return true;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment