diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index f91a10a589fc6217e68237612ff28b42d8805b67..1e24214877537af62f25e99fdbf2742315fae15c 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -100,10 +100,9 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool } catch(std::exception & e) {util::myThrow(e.what());} // Force EOS when needed - if (machine.getTransitionSet().getTransition("EOS") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1) + if (machine.getTransitionSet().getTransition("EOS b.0") and config.getLastNotEmptyHypConst(Config::EOSColName, config.getWordIndex()) != Config::EOSSymbol1) { - machine.getTransitionSet().getTransition("SHIFT")->apply(config); - machine.getTransitionSet().getTransition("EOS")->apply(config); + machine.getTransitionSet().getTransition("EOS b.0")->apply(config); if (debug) { fmt::print(stderr, "Forcing EOS transition\n"); diff --git a/reading_machine/include/Strategy.hpp b/reading_machine/include/Strategy.hpp index 706d0faa31c25de3e0253021d2914f09accc07a9..79fc26cbc917d59ce17f93024f33dc60652e6b88 100644 --- a/reading_machine/include/Strategy.hpp +++ b/reading_machine/include/Strategy.hpp @@ -7,32 +7,47 @@ class Strategy { public : - static inline std::pair<std::string, int> endMovement{"", 0}; + using Movement = std::pair<std::string, int>; + static inline Movement endMovement{"", 0}; private : - enum Type + class Block { - Incremental, - Sequential - }; + private : - Type type; - std::map<std::pair<std::string, std::string>, std::pair<std::string, int>> edges; - std::map<std::string, bool> isDone; - std::vector<std::string> defaultCycle; - std::vector<std::string> originalDefaultCycle; - std::string initialState{"UNDEFINED"}; + enum EndCondition + { + CannotMove + }; + + std::vector<EndCondition> endConditions; + std::vector<std::tuple<std::string,std::string,std::string,int>> movements; + + private : + + static EndCondition str2condition(const std::string & s); + + public : + + Block(std::vector<std::string> endConditionsStr); + void addMovement(std::string definition); + const std::string getInitialState() const; + bool empty(); + Movement getMovement(const Config & c, const std::string & transition); + bool isFinished(const Config & c, const Movement & movement); + }; private : - std::pair<std::string, int> getMovementSequential(const Config & c, const std::string & transition); - std::pair<std::string, int> getMovementIncremental(const Config & c, const std::string & transition); + std::string initialState{"UNDEFINED"}; + std::vector<Block> blocks; + std::size_t currentBlock{0}; public : - Strategy(std::vector<std::string> lines); - std::pair<std::string, int> getMovement(const Config & c, const std::string & transition); + Strategy(std::vector<std::string> definition); + Movement getMovement(const Config & c, const std::string & transition); const std::string getInitialState() const; void reset(); }; diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index e172acbc9806fc550a7fc41b362f55cc76bd23d2..f402961779285500730d0a02ecc08d37ee6912b0 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -104,9 +104,21 @@ void ReadingMachine::readFromFile(std::filesystem::path path) })) util::myThrow("No predictions specified"); - auto restOfFile = std::vector<std::string>(lines.begin()+curLine, lines.end()); + if (!util::doIfNameMatch(std::regex("Strategy"), lines[curLine++], [this,&lines,&curLine](auto sm) + { + std::vector<std::string> strategyDefinition; + if (lines[curLine] != "{") + util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine])); - strategy.reset(new Strategy(restOfFile)); + for (curLine++; curLine < lines.size(); curLine++) + { + if (lines[curLine] == "}") + break; + strategyDefinition.emplace_back(lines[curLine]); + } + strategy.reset(new Strategy(strategyDefinition)); + })) + util::myThrow("No Strategy specified"); } catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));} } diff --git a/reading_machine/src/Strategy.cpp b/reading_machine/src/Strategy.cpp index 7a82a8b8730aa77ed0e8bcdd01ac52dd9fc8d0bf..af65052e489dca55c0fa08b5f38ccc65019c9b11 100644 --- a/reading_machine/src/Strategy.cpp +++ b/reading_machine/src/Strategy.cpp @@ -1,151 +1,123 @@ #include "Strategy.hpp" -Strategy::Strategy(std::vector<std::string> lines) +Strategy::Strategy(std::vector<std::string> definition) { - if (!util::doIfNameMatch(std::regex("Strategy : ((incremental)|(sequential))"), lines[0], [this](auto sm) - {type = sm[1] == "sequential" ? Type::Sequential : Type::Incremental;})) - util::myThrow(fmt::format("Invalid strategy identifier '{}'", lines[0])); + std::regex blockRegex("(?:(?:\\s|\\t)*)Block(?:(?:\\s|\\t)*):(?:(?:\\s|\\t)*)End\\{(.*)\\}(?:(?:\\s|\\t)*)"); - for (unsigned int i = 1; i < lines.size(); i++) - { - std::replace(lines[i].begin(), lines[i].end(), '\t', ' '); - auto splited = util::split(lines[i], ' '); - std::pair<std::string, std::string> key; - std::string value; - int movement; - - if (splited.size() == 3) - { - key = std::pair<std::string,std::string>(splited[0], ""); - value = splited[1]; - movement = std::stoi(std::string(splited[2])); - if (defaultCycle.empty()) - initialState = splited[0]; - defaultCycle.emplace_back(value); - } - else if (splited.size() == 4) + for (auto & line : definition) + if (!util::doIfNameMatch(blockRegex, line, [this](auto sm) + { + blocks.emplace_back(util::split(sm.str(1), ' ')); + })) { - key = std::pair<std::string,std::string>(splited[0], splited[2]); - value = splited[1]; - movement = std::stoi(std::string(splited[3])); + if (blocks.empty()) + util::myThrow(fmt::format("invalid line '{}', expeced 'Block : End{}'",line,"{...}")); + blocks.back().addMovement(line); } - else - util::myThrow(fmt::format("Invalid strategy line '{}'", lines[i])); - - if (edges.count(key)) - util::myThrow(fmt::format("Edge {} {} defined twice", key.first, key.second)); - edges[key] = std::make_pair(value, movement); - isDone[key.first] = false; - } - if (edges.empty()) - util::myThrow("Strategy is empty"); - if (type == Type::Sequential) - defaultCycle.pop_back(); - std::reverse(defaultCycle.begin(), defaultCycle.end()); - originalDefaultCycle = defaultCycle; + if (blocks.empty()) + util::myThrow("empty strategy"); + for (auto & block : blocks) + if (block.empty()) + util::myThrow("there is an empty block"); } -std::pair<std::string, int> Strategy::getMovement(const Config & c, const std::string & transition) +Strategy::Movement Strategy::getMovement(const Config & c, const std::string & transition) { - std::string transitionPrefix(util::split(transition, ' ')[0]); - - if (c.stateIsDone()) - isDone[c.getState()] = true; + auto movement = blocks[currentBlock].getMovement(c, transition); - if (type == Type::Sequential) + if (blocks[currentBlock].isFinished(c, movement)) { - while (defaultCycle.size() && isDone[defaultCycle.back()]) - defaultCycle.pop_back(); - - return getMovementSequential(c, transitionPrefix); + currentBlock++; + if (currentBlock >= blocks.size()) + return endMovement; + movement.first = blocks[currentBlock].getInitialState(); + movement.second = -c.getWordIndex(); } - for (unsigned int i = 0; i < defaultCycle.size(); i++) - if (isDone[defaultCycle[i]]) - { - while (defaultCycle.size() != i) - defaultCycle.pop_back(); - break; - } - - return getMovementIncremental(c, transitionPrefix); + return movement; } -std::pair<std::string, int> Strategy::getMovementSequential(const Config & c, const std::string & transition) +Strategy::Movement Strategy::Block::getMovement(const Config & c, const std::string & transition) { - auto foundSpecific = edges.find(std::make_pair(c.getState(), transition)); - auto foundGeneric = edges.find(std::make_pair(c.getState(), "")); - - std::string target; - int movement = -1; + std::string transitionPrefix(util::split(transition, ' ')[0]); + auto currentState = c.getState(); - if (foundSpecific != edges.end()) - { - target = foundSpecific->second.first; - movement = foundSpecific->second.second; - } - else if (foundGeneric != edges.end()) + for (auto & movement : movements) { - target = foundGeneric->second.first; - movement = foundGeneric->second.second; - } - - if (target.empty()) - util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition)); - - if (!c.stateIsDone()) - return {c.getState(), c.canMoveWordIndex(movement) ? movement : 0}; + auto fromState = std::get<0>(movement); + auto toState = std::get<1>(movement); + auto trans = std::get<2>(movement); + auto mov = std::get<3>(movement); - if (!isDone[target]) - return {target, -c.getWordIndex()}; + if (fromState == currentState and (trans == transitionPrefix or trans == "*")) + return std::make_pair(toState, mov); + } + util::myThrow(fmt::format("no movement found for state '{}' and transitionPrefix '{}'", currentState, transitionPrefix)); return endMovement; } -std::pair<std::string, int> Strategy::getMovementIncremental(const Config & c, const std::string & transition) +const std::string Strategy::getInitialState() const { - auto foundSpecific = edges.find(std::make_pair(c.getState(), transition)); - auto foundGeneric = edges.find(std::make_pair(c.getState(), "")); + return blocks.at(0).getInitialState(); +} - std::string target; - int movement = -1; +void Strategy::reset() +{ + currentBlock = 0; +} - if (foundSpecific != edges.end()) - { - target = foundSpecific->second.first; - movement = foundSpecific->second.second; - } - else if (foundGeneric != edges.end()) - { - target = foundGeneric->second.first; - movement = foundGeneric->second.second; - } +Strategy::Block::Block(std::vector<std::string> endConditionsStr) +{ + for (auto & cond : endConditionsStr) + endConditions.emplace_back(str2condition(cond)); +} - if (target.empty()) - util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition)); +void Strategy::Block::addMovement(std::string definition) +{ + std::regex regex("(?:(?:\\s|\\t)*)(\\S+)(?:(?:\\s|\\t)+)(\\S+)(?:(?:\\s|\\t)+)(\\S+)(?:(?:\\s|\\t)+)(\\S+)(?:(?:\\s|\\t)*)"); + auto errorMessage = fmt::format("invalid definition '{}' expected fromState toState transitionNamePrefix wordIndexMovement", definition); - if (c.hasStack(0) and c.getStack(0) == c.getWordIndex() and not c.canMoveWordIndex(movement)) - target = c.getState(); + if (!util::doIfNameMatch(regex, definition, [this, &errorMessage](auto sm) + { + try + { + movements.emplace_back(std::make_tuple(sm.str(1), sm.str(2), sm.str(3), std::stoi(sm.str(4)))); + } catch (std::exception & e) {util::myThrow(fmt::format("caught '{}' in {}", e.what(), errorMessage));} + })) + util::myThrow(errorMessage); +} - if (!isDone[target]) - return {target, c.canMoveWordIndex(movement) ? movement : 0}; +Strategy::Block::EndCondition Strategy::Block::str2condition(const std::string & s) +{ + if (s == "cannotMove") + return EndCondition::CannotMove; + else + util::myThrow(fmt::format("invalid EndCondition '{}'", s)); - if (defaultCycle.empty()) - return endMovement; + return EndCondition::CannotMove; +} - return {defaultCycle.back(), movement}; +const std::string Strategy::Block::getInitialState() const +{ + return std::get<0>(movements.at(0)); } -const std::string Strategy::getInitialState() const +bool Strategy::Block::empty() { - return initialState; + return movements.empty(); } -void Strategy::reset() +bool Strategy::Block::isFinished(const Config & c, const Movement & movement) { - for (auto & it : isDone) - it.second = false; - defaultCycle = originalDefaultCycle; + for (auto & condition : endConditions) + if (condition == EndCondition::CannotMove) + { + if (c.canMoveWordIndex(movement.second)) + return false; + } + + return true; }