From 14bcdc4e9a4eb4c161b9013e649eeaff00ef047e Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 2 Mar 2020 22:32:50 +0100
Subject: [PATCH] Added actions for feats prediction

---
 reading_machine/include/Action.hpp     |  2 +
 reading_machine/include/Transition.hpp |  2 +
 reading_machine/src/Action.cpp         | 77 ++++++++++++++++++++++++++
 reading_machine/src/Transition.cpp     | 42 ++++++++++++++
 4 files changed, 123 insertions(+)

diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp
index a20f68a..994fe9c 100644
--- a/reading_machine/include/Action.hpp
+++ b/reading_machine/include/Action.hpp
@@ -49,7 +49,9 @@ class Action
   static Action moveWordIndex(int movement);
   static Action moveCharacterIndex(int movement);
   static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis);
+  static Action addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition);
   static Action addHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & hypothesis);
+  static Action addToHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & addition);
   static Action pushWordIndexOnStack();
   static Action popStack();
   static Action emptyStack();
diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp
index c3a4589..c7309a6 100644
--- a/reading_machine/include/Transition.hpp
+++ b/reading_machine/include/Transition.hpp
@@ -17,11 +17,13 @@ class Transition
   private :
 
   void initWrite(std::string colName, std::string object, std::string index, std::string value);
+  void initAdd(std::string colName, std::string object, std::string index, std::string value);
   void initShift();
   void initLeft(std::string label);
   void initRight(std::string label);
   void initReduce();
   void initEOS();
+  void initNothing();
 
   public :
 
diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp
index 6cf1678..ca95473 100644
--- a/reading_machine/src/Action.cpp
+++ b/reading_machine/src/Action.cpp
@@ -87,6 +87,83 @@ Action Action::addHypothesis(const std::string & colName, std::size_t lineIndex,
   return {Type::Write, apply, undo, appliable};
 }
 
+// Build an action that appends `addition` to the '|'-separated hypothesis of column `colName` at line `lineIndex`.
+// The action is not appliable when the cell does not exist or when `addition` is already one of its elements.
+Action Action::addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition)
+{
+  auto apply = [colName, lineIndex, addition](Config & config, Action &)
+  {
+    auto & current = config.getLastNotEmptyHyp(colName, lineIndex);
+    // Keep the elements already present: append to them instead of overwriting them.
+    current = util::isEmpty(current) ? addition : current.get() + '|' + addition;
+  };
+
+  auto undo = [colName, lineIndex](Config & config, Action &)
+  {
+    // Remove the last '|'-separated element (the one apply appended), whatever its length.
+    std::string newValue = config.getLastNotEmpty(colName, lineIndex);
+    auto lastSeparator = newValue.rfind('|');
+    newValue.erase(lastSeparator == std::string::npos ? 0 : lastSeparator);
+    config.getLastNotEmpty(colName, lineIndex) = newValue;
+  };
+
+  auto appliable = [colName, lineIndex, addition](const Config & config, const Action &)
+  {
+    if (!config.has(colName, lineIndex, 0))
+      return false;
+    auto & current = config.getLastNotEmptyHypConst(colName, lineIndex);
+    auto splited = util::split(current.get(), '|');
+    for (auto & part : splited)
+      if (part == addition)
+        return false;
+    return true;
+  };
+
+  return {Type::Write, apply, undo, appliable};
+}
+
+// Same as addToHypothesis, but the target line is resolved when the action runs :
+// word index + relativeIndex for Object::Buffer, otherwise the stack element at relativeIndex.
+Action Action::addToHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & addition)
+{
+  auto apply = [colName, object, relativeIndex, addition](Config & config, Action & a)
+  {
+    int lineIndex = 0;
+    if (object == Object::Buffer)
+      lineIndex = config.getWordIndex() + relativeIndex;
+    else
+      lineIndex = config.getStack(relativeIndex);
+
+    return addToHypothesis(colName, lineIndex, addition).apply(config, a);
+  };
+
+  auto undo = [colName, object, relativeIndex](Config & config, Action & a)
+  {
+    int lineIndex = 0;
+    if (object == Object::Buffer)
+      lineIndex = config.getWordIndex() + relativeIndex;
+    else
+      lineIndex = config.getStack(relativeIndex);
+
+    return addToHypothesis(colName, lineIndex, "").undo(config, a);
+  };
+
+  auto appliable = [colName, object, relativeIndex, addition](const Config & config, const Action & a)
+  {
+    int lineIndex = 0;
+    if (object == Object::Buffer)
+      lineIndex = config.getWordIndex() + relativeIndex;
+    else if (config.hasStack(relativeIndex))
+      lineIndex = config.getStack(relativeIndex);
+    else
+      return false;
+
+    return addToHypothesis(colName, lineIndex, addition).appliable(config, a);
+  };
+
+  return {Type::Write, apply, undo, appliable};
+}
+
 Action Action::addHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & hypothesis)
 {
   auto apply = [colName, object, relativeIndex, hypothesis](Config & config, Action & a)
diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp
index f18450c..259e7f4 100644
--- a/reading_machine/src/Transition.cpp
+++ b/reading_machine/src/Transition.cpp
@@ -5,11 +5,13 @@ Transition::Transition(const std::string & name)
 {
   std::regex nameRegex("(<(.+)> )?(.+)");
   std::regex writeRegex("WRITE ([bs])\\.(.+) (.+) (.+)");
+  std::regex addRegex("ADD ([bs])\\.(.+) (.+) (.+)");
   std::regex shiftRegex("SHIFT");
   std::regex reduceRegex("REDUCE");
   std::regex leftRegex("LEFT (.+)");
   std::regex rightRegex("RIGHT (.+)");
   std::regex eosRegex("EOS");
+  std::regex nothingRegex("NOTHING");
 
   try
   {
@@ -22,6 +24,8 @@ Transition::Transition(const std::string & name)
 
     if (util::doIfNameMatch(writeRegex, this->name, [this](auto sm){initWrite(sm[3], sm[1], sm[2], sm[4]);}))
       return;
+    if (util::doIfNameMatch(addRegex, this->name, [this](auto sm){initAdd(sm[3], sm[1], sm[2], sm[4]);}))
+      return;
     if (util::doIfNameMatch(shiftRegex, this->name, [this](auto){initShift();}))
       return;
     if (util::doIfNameMatch(reduceRegex, this->name, [this](auto){initReduce();}))
@@ -32,6 +36,8 @@ Transition::Transition(const std::string & name)
       return;
     if (util::doIfNameMatch(eosRegex, this->name, [this](auto){initEOS();}))
       return;
+    if (util::doIfNameMatch(nothingRegex, this->name, [this](auto){initNothing();}))
+      return;
 
     throw std::invalid_argument("no match");
 
@@ -89,6 +95,42 @@ void Transition::initWrite(std::string colName, std::string object, std::string
   };
 }
 
+// Set up an "ADD" transition : one addToHypothesisRelative action.
+// Its cost is 0 when `value` is one of the '|'-separated elements of hypothesis 0 of the target cell, 1 otherwise.
+void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
+{
+  auto objectValue = Action::str2object(object);
+  int indexValue = std::stoi(index);
+
+  sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));
+
+  cost = [colName, objectValue, indexValue, value](const Config & config)
+  {
+    int lineIndex = 0;
+    if (objectValue == Action::Object::Buffer)
+      lineIndex = config.getWordIndex() + indexValue;
+    else
+      lineIndex = config.getStack(indexValue);
+
+    auto gold = util::split(config.getConst(colName, lineIndex, 0).get(), '|');
+
+    for (auto & part : gold)
+      if (part == value)
+        return 0;
+
+    return 1;
+  };
+}
+
+// Set up a "NOTHING" transition : no action is added to the sequence and the cost is always 0.
+void Transition::initNothing()
+{
+  cost = [](const Config &)
+  {
+    return 0;
+  };
+}
+
 void Transition::initShift()
 {
   sequence.emplace_back(Action::pushWordIndexOnStack());
-- 
GitLab