diff --git a/transition_machine/include/ActionBank.hpp b/transition_machine/include/ActionBank.hpp index 0d81f427ff56fa7a98d39d72498f41196776a24f..b5291b7d3c96261056968960ba792c3f9026e2e6 100644 --- a/transition_machine/include/ActionBank.hpp +++ b/transition_machine/include/ActionBank.hpp @@ -95,11 +95,29 @@ class ActionBank /// /// \param tapeName The tape we will write to /// \param value The value we will write - /// \param relativeIndex The write index relmative to the buffer's head + /// \param relativeIndex The write index relative to the buffer's head /// /// \return A BasicAction doing all of that static Action::BasicAction bufferWrite(std::string tapeName, std::string value, int relativeIndex); + /// \brief Append a string to a cell in the buffer + /// + /// \param tapeName The tape we will write to + /// \param value The value we will append + /// \param relativeIndex The write index relative to the buffer's head + /// + /// \return A BasicAction doing all of that + static Action::BasicAction bufferAdd(std::string tapeName, std::string value, int relativeIndex); + + /// \brief Append a string to a cell in the stack + /// + /// \param tapeName The tape we will write to + /// \param value The value we will append + /// \param relativeIndex The write index relative to the stack's head + /// + /// \return A BasicAction doing all of that + static Action::BasicAction stackAdd(std::string tapeName, std::string value, int bufferIndex); + /// \brief Move the buffer's head /// /// \param movement The relative movement of the buffer's head @@ -129,6 +147,11 @@ class ActionBank /// \return A BasicAction only appliable if the tape at relativeIndex is not empty. static Action::BasicAction checkNotEmpty(std::string tape, int relativeIndex); + /// \brief Verify that the config is not final. + /// + /// \return A BasicAction only appliable if the config is not final. + static Action::BasicAction checkConfigIsNotFinal(); + /// \brief Write something on the buffer /// /// \param tapeName The tape we will write to diff --git a/transition_machine/src/ActionBank.cpp b/transition_machine/src/ActionBank.cpp index 047466c8a208a4c25fd1cf2f74ce0ae89a26463a..a5ee86f36978ebae657de939b819bbdb0ac6ae2c 100644 --- a/transition_machine/src/ActionBank.cpp +++ b/transition_machine/src/ActionBank.cpp @@ -75,6 +75,22 @@ Action::BasicAction ActionBank::checkRawInputHeadIsSpace() return basicAction; } +Action::BasicAction ActionBank::checkConfigIsNotFinal() +{ + auto apply = [](Config &, Action::BasicAction &) + {}; + auto undo = [](Config &, Action::BasicAction &) + {}; + auto appliable = [](Config & c, Action::BasicAction &) + { + return !c.isFinal(); + }; + Action::BasicAction basicAction = + {Action::BasicAction::Type::Write, "", apply, undo, appliable}; + + return basicAction; +} + Action::BasicAction ActionBank::checkRawInputHeadIsSeparator() { auto apply = [](Config &, Action::BasicAction &) @@ -163,6 +179,108 @@ Action::BasicAction ActionBank::stackWrite(std::string tapeName, std::string val return basicAction; } +Action::BasicAction ActionBank::bufferAdd(std::string tapeName, std::string value, int relativeIndex) +{ + auto apply = [tapeName, value, relativeIndex](Config & config, Action::BasicAction &) + { + auto & tape = config.getTape(tapeName); + auto & from = tape.getHyp(relativeIndex); + + auto parts = util::split(from, '|'); + parts.emplace_back(value); + + std::sort(parts.begin(), parts.end()); + + std::string newValue; + for (auto & part : parts) + newValue += part + '|'; + + newValue.pop_back(); + + tape.setHyp(relativeIndex, newValue); + }; + auto undo = [tapeName, relativeIndex](Config & config, Action::BasicAction &) + { + auto & tape = config.getTape(tapeName); + auto from = tape.getHyp(relativeIndex); + while (!from.empty() && from.back() != '|') + from.pop_back(); + if (!from.empty() && from.back() == '|') + from.pop_back(); + + tape.setHyp(relativeIndex, from); + }; + auto appliable = [tapeName, relativeIndex, value](Config & config, Action::BasicAction &) + { + if (config.isFinal()) + return false; + + auto & tape = config.getTape(tapeName); + auto & from = tape.getHyp(relativeIndex); + + auto splited = util::split(from, '|'); + for (auto & part : splited) + if (part == value) + return false; + + return true; + }; + Action::BasicAction basicAction = + {Action::BasicAction::Type::Write, value, apply, undo, appliable}; + + return basicAction; +} + +Action::BasicAction ActionBank::stackAdd(std::string tapeName, std::string value, int stackIndex) +{ + auto apply = [tapeName, value, stackIndex](Config & c, Action::BasicAction &) + { + int bufferIndex = c.stackGetElem(stackIndex); + int relativeIndex = bufferIndex - c.getHead(); + auto & tape = c.getTape(tapeName); + auto & from = tape.getHyp(relativeIndex); + + if (!from.empty()) + tape.setHyp(relativeIndex, from+'|'+value); + else + tape.setHyp(relativeIndex, value); + }; + auto undo = [tapeName, stackIndex](Config & c, Action::BasicAction &) + { + int bufferIndex = c.stackGetElem(stackIndex); + int relativeIndex = bufferIndex - c.getHead(); + auto & tape = c.getTape(tapeName); + auto from = tape.getHyp(relativeIndex); + while (!from.empty() && from.back() != '|') + from.pop_back(); + if (!from.empty() && from.back() == '|') + from.pop_back(); + + tape.setHyp(relativeIndex, from); + }; + auto appliable = [tapeName, stackIndex, value](Config & c, Action::BasicAction &) + { + if (c.isFinal()) + return false; + + int bufferIndex = c.stackGetElem(stackIndex); + int relativeIndex = bufferIndex - c.getHead(); + auto & tape = c.getTape(tapeName); + auto & from = tape.getHyp(relativeIndex); + + auto splited = util::split(from, '|'); + for (auto & part : splited) + if (part == value) + return false; + + return true; + }; + Action::BasicAction basicAction = + {Action::BasicAction::Type::Write, value, apply, undo, appliable}; + + return basicAction; +} + Action::BasicAction ActionBank::pushHead() { auto apply = [](Config & c, Action::BasicAction &) @@ -253,6 +371,25 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na else if (object[0] == "s") sequence.emplace_back(stackWrite(tapeName, value, relativeIndex)); } + else if(std::string(b1) == "ADD") + { + if (sscanf(name.c_str(), "%s %s %s %s", b1, b4, b2, b3) != 4) + invalidNameAndAbort(ERRINFO); + + std::string tapeName(b2); + std::string value(b3); + auto object = util::split(b4, '.'); + + if (object.size() != 2) + invalidNameAndAbort(ERRINFO); + + int relativeIndex = std::stoi(object[1]); + + if (object[0] == "b") + sequence.emplace_back(bufferAdd(tapeName, value, relativeIndex)); + else if (object[0] == "s") + sequence.emplace_back(stackAdd(tapeName, value, relativeIndex)); + } else if(std::string(b1) == "MULTIWRITE") { int startRelIndex; @@ -293,6 +430,7 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na } else if(std::string(b1) == "NOTHING") { + sequence.emplace_back(checkConfigIsNotFinal()); } else if(std::string(b1) == "EPSILON") { diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp index baf0789e09c4578737c423d95a04ba8f385ad148..4577984dc2c0d22e235bdb77bf50aedca739f842 100644 --- a/transition_machine/src/Oracle.cpp +++ b/transition_machine/src/Oracle.cpp @@ -286,10 +286,41 @@ void Oracle::createDatabase() }, [](Config & c, Oracle *, const std::string & action) { - return action == "WRITE b.0 MORPHO " + c.getTape("MORPHO").getRef(0) ? 0 : 1; + if (!strncmp("WRITE", action.c_str(), 5)) + return action == "WRITE b.0 MORPHO " + c.getTape("MORPHO").getRef(0) ? 0 : 1; + + auto partsRef = util::split(c.getTape("MORPHO").getRef(0), '|'); + auto partsHyp = util::split(c.getTape("MORPHO").getHyp(0), '|'); + + if (!strncmp("NOTHING", action.c_str(), 7)) + { + int diff = std::abs((int)(partsRef.size()-partsHyp.size())); + return partsRef == partsHyp ? 0 : diff; + } + + if (strncmp("ADD", action.c_str(), 3)) + return 1; + + auto actionPart = util::split(action, ' ').back(); + + std::set<std::string> presentHyp; + std::set<std::string> presentRef; + for (auto & part : partsHyp) + presentHyp.insert(part); + for (auto & part : partsRef) + presentRef.insert(part); + + int cost = 0; + + if (!presentRef.count(actionPart)) + cost++; + if (presentHyp.count(actionPart)) + cost++; + + return cost; }))); - str2oracle.emplace("strategy_morpho", std::unique_ptr<Oracle>(new Oracle( + str2oracle.emplace("strategy_morpho_whole", std::unique_ptr<Oracle>(new Oracle( [](Oracle *) { }, @@ -302,6 +333,23 @@ void Oracle::createDatabase() return 0; }))); + str2oracle.emplace("strategy_morpho_parts", std::unique_ptr<Oracle>(new Oracle( + [](Oracle *) + { + }, + [](Config & c, Oracle *) + { + std::string previousAction = util::noAccentLower(c.pastActions.getElem(0).second.name); + if (previousAction == "nothing") + return std::string("MOVE morpho 1"); + + return std::string("MOVE morpho 0"); + }, + [](Config &, Oracle *, const std::string &) + { + return 0; + }))); + str2oracle.emplace("strategy_tagger", std::unique_ptr<Oracle>(new Oracle( [](Oracle *) {