diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp index 3fd0aef7e122d83d7f8dd91feda82347d95ca252..5fcffe2aff0c0b9359594a11a6c1eb0768ff9bdf 100644 --- a/reading_machine/include/Action.hpp +++ b/reading_machine/include/Action.hpp @@ -50,6 +50,9 @@ class Action static Action moveCharacterIndex(int movement); static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis); static Action addHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & hypothesis); + static Action pushWordIndexOnStack(); + static Action popStack(); + static Action attach(Object governorObject, int governorIndex, Object dependentObject, int dependentIndex); }; #endif diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index 82b344499c2f782e225bfcd0f298ab0eba54fd3b..daaeae1187c76b2db6e492e48001b0283ad3b84e 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -11,12 +11,13 @@ class Config { - protected : + public : static constexpr const char * EOSColName = "EOS"; static constexpr const char * EOSSymbol1 = "1"; static constexpr const char * EOSSymbol0 = "0"; static constexpr const char * headColName = "HEAD"; + static constexpr const char * deprelColName = "DEPREL"; static constexpr const char * idColName = "ID"; static constexpr int nbHypothesesMax = 1; @@ -86,6 +87,7 @@ class Config util::utf8char getLetter(int letterIndex) const; void addToHistory(const std::string & transition); void addToStack(std::size_t index); + void popStack(); bool isComment(std::size_t lineIndex) const; bool isMultiword(std::size_t lineIndex) const; bool isEmptyNode(std::size_t lineIndex) const; diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp index cbdf21f8fda73f3f79ec462b104067642f38ecfe..019a077b98f8b9df2c44e302f210ed39c792657d 100644 --- a/reading_machine/include/Transition.hpp +++ b/reading_machine/include/Transition.hpp @@ -16,6 +16,10 @@ class Transition private : void initWrite(std::string colName, std::string object, std::string index, std::string value); + void initShift(); + void initLeft(std::string label); + void initRight(std::string label); + void initReduce(); public : diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index 7dcaf78439d833b87fd439a676fded9cc60392ca..1cca6d48788229f13685ece03536ad9dc4e1a901 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -127,6 +127,87 @@ Action Action::addHypothesisRelative(const std::string & colName, Object object, return {Type::Write, apply, undo, appliable}; } +Action Action::pushWordIndexOnStack() +{ + auto apply = [](Config & config, Action &) + { + config.addToStack(config.getWordIndex()); + }; + + auto undo = [](Config & config, Action &) + { + config.popStack(); + }; + + auto appliable = [](const Config &, const Action &) + { + return true; + }; + + return {Type::Push, apply, undo, appliable}; +} + +Action Action::popStack() +{ + auto apply = [](Config & config, Action & a) + { + auto toSave = config.getStack(0); + a.data.push_back(std::to_string(toSave)); + config.popStack(); + }; + + auto undo = [](Config & config, Action & a) + { + config.addToStack(std::stoi(a.data.back())); + }; + + auto appliable = [](const Config & config, const Action &) + { + return config.hasStack(0); + }; + + return {Type::Pop, apply, undo, appliable}; +} + +Action Action::attach(Object governorObject, int governorIndex, Object dependentObject, int dependentIndex) +{ + auto apply = [governorObject, governorIndex, dependentObject, dependentIndex](Config & config, Action & a) + { + int lineIndex = 0; + if (governorObject == Object::Buffer) + lineIndex = config.getWordIndex() + governorIndex; + else + lineIndex = config.getStack(governorIndex); + addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, config.getLastNotEmptyConst(Config::idColName, lineIndex)).apply(config, a); + }; + + auto undo = [governorObject, governorIndex, dependentObject, dependentIndex](Config & config, Action & a) + { + addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, "").undo(config, a); + }; + + auto appliable = [governorObject, governorIndex, dependentObject, dependentIndex](const Config & config, const Action & action) + { + int govLineIndex = 0; + if (governorObject == Object::Buffer) + { + govLineIndex = config.getWordIndex() + governorIndex; + if (!config.has(0, govLineIndex, 0)) + return false; + } + else + { + if (!config.hasStack(governorIndex)) + return false; + govLineIndex = config.getStack(governorIndex); + } + + return addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, config.getLastNotEmptyConst(Config::idColName, govLineIndex)).appliable(config, action); + }; + + return {Type::Write, apply, undo, appliable}; +} + Action::Object Action::str2object(const std::string & s) { if (s == "b") diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index bd640ffbf40e41cbbf4e2f94e06173e6101ca9b8..40feaf4560bccbb0279f028cbe5c417c056a3bdf 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -262,6 +262,11 @@ void Config::addToStack(std::size_t index) stack.push_back(index); } +void Config::popStack() +{ + stack.pop_back(); +} + bool Config::hasCharacter(int letterIndex) const { return letterIndex >= 0 and letterIndex < (int)util::getSize(rawInput); diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 46147d4264cc6f10ab4e89d42230e5d9bfcf23c6..a66ba5365d341e788ecc791b93cba32b730c8ac2 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -6,12 +6,24 @@ Transition::Transition(const std::string & name) this->name = name; std::regex writeRegex("WRITE ([bs])\\.(.+) (.+) (.+)"); + std::regex shiftRegex("SHIFT"); + std::regex reduceRegex("REDUCE"); + std::regex leftRegex("LEFT (.+)"); + std::regex rightRegex("RIGHT (.+)"); try { if (util::doIfNameMatch(writeRegex, name, [this](auto sm){initWrite(sm[3], sm[1], sm[2], sm[4]);})) return; + if (util::doIfNameMatch(shiftRegex, name, [this](auto){initShift();})) + return; + if (util::doIfNameMatch(reduceRegex, name, [this](auto){initReduce();})) + return; + if (util::doIfNameMatch(leftRegex, name, [this](auto sm){initLeft(sm[1]);})) + return; + if (util::doIfNameMatch(rightRegex, name, [this](auto sm){initRight(sm[1]);})) + return; throw std::invalid_argument("no match"); @@ -39,6 +51,11 @@ int Transition::getCost(const Config & config) const return cost(config); } +const std::string & Transition::getName() const +{ + return name; +} + void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value) { auto objectValue = Action::str2object(object); @@ -61,8 +78,164 @@ void Transition::initWrite(std::string colName, std::string object, std::string }; } -const std::string & Transition::getName() const +void Transition::initShift() { - return name; + sequence.emplace_back(Action::pushWordIndexOnStack()); + + cost = [](const Config & config) + { + if (config.isToken(config.getWordIndex())) + return 0; + + auto headGov = config.getConst(Config::headColName, config.getWordIndex(), 0); + auto headId = config.getConst(Config::idColName, config.getWordIndex(), 0); + + int cost = 0; + for (int i = 0; config.hasStack(i); ++i) + { + auto stackIndex = config.getStack(i); + auto stackId = config.getConst(Config::idColName, stackIndex, 0); + auto stackGov = config.getConst(Config::headColName, stackIndex, 0); + + if (stackGov == headId || headGov == stackId) + ++cost; + } + + return cost; + }; +} + +void Transition::initLeft(std::string label) +{ + sequence.emplace_back(Action::attach(Action::Object::Buffer, 0, Action::Object::Stack, 0)); + sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Stack, 0, label)); + sequence.emplace_back(Action::popStack()); + + cost = [label](const Config & config) + { + auto stackIndex = config.getStack(0); + auto wordIndex = config.getWordIndex(); + if (!(config.isToken(stackIndex) && config.isToken(wordIndex))) + return std::numeric_limits<int>::max(); + + int cost = 0; + + auto idOfStack = config.getConst(Config::idColName, stackIndex, 0); + auto govIdOfStack = config.getConst(Config::headColName, stackIndex, 0); + + for (int i = wordIndex+1; config.has(0, i, 0); ++i) + { + if (!config.isToken(i)) + continue; + + if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1) + break; + + auto idOfOther = config.getConst(Config::idColName, i, 0); + auto govIdOfOther = config.getConst(Config::headColName, i, 0); + + if (govIdOfStack == idOfOther || govIdOfOther == idOfStack) + ++cost; + } + + //TODO : Check if this is necessary + if (govIdOfStack != config.getConst(Config::idColName, wordIndex, 0)) + ++cost; + + if (label != config.getConst(Config::deprelColName, stackIndex, 0)) + ++cost; + + return cost; + }; +} + +void Transition::initRight(std::string label) +{ + sequence.emplace_back(Action::attach(Action::Object::Stack, 0, Action::Object::Buffer, 0)); + sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Buffer, 0, label)); + sequence.emplace_back(Action::pushWordIndexOnStack()); + + cost = [label](const Config & config) + { + auto stackIndex = config.getStack(0); + auto wordIndex = config.getWordIndex(); + if (!(config.isToken(stackIndex) && config.isToken(wordIndex))) + return std::numeric_limits<int>::max(); + + int cost = 0; + + auto idOfBuffer = config.getConst(Config::idColName, wordIndex, 0); + auto govIdOfBuffer = config.getConst(Config::headColName, wordIndex, 0); + + for (int i = wordIndex; config.has(0, i, 0); ++i) + { + if (!config.isToken(i)) + continue; + + auto idOfOther = config.getConst(Config::idColName, i, 0); + auto govIdOfOther = config.getConst(Config::headColName, i, 0); + + if (govIdOfBuffer == idOfOther || govIdOfOther == idOfBuffer) + ++cost; + + if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1) + break; + } + + for (int i = 1; config.hasStack(i); ++i) + { + auto otherStackIndex = config.getStack(i); + auto stackId = config.getConst(Config::idColName, otherStackIndex, 0); + auto stackGov = config.getConst(Config::headColName, otherStackIndex, 0); + + if (stackGov == idOfBuffer || govIdOfBuffer == stackId) + ++cost; + } + + //TODO : Check if this is necessary + if (govIdOfBuffer != config.getConst(Config::idColName, stackIndex, 0)) + ++cost; + + if (label != config.getConst(Config::deprelColName, wordIndex, 0)) + ++cost; + + return cost; + }; +} + +void Transition::initReduce() +{ + sequence.emplace_back(Action::popStack()); + + cost = [](const Config & config) + { + if (config.isToken(config.getWordIndex())) + return 0; + + int cost = 0; + + auto idOfStack = config.getConst(Config::idColName, config.getStack(0), 0); + auto govIdOfStack = config.getConst(Config::headColName, config.getStack(0), 0); + + for (int i = config.getWordIndex(); config.has(0, i, 0); ++i) + { + if (!config.isToken(i)) + continue; + + auto idOfOther = config.getConst(Config::idColName, i, 0); + auto govIdOfOther = config.getConst(Config::headColName, i, 0); + + if (govIdOfStack == idOfOther || govIdOfOther == idOfStack) + ++cost; + + if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1) + break; + } + + if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1) + ++cost; + + return cost; + }; }