From 6a04a795b2da967925c0fae9bb17919ebdbcbd06 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Fri, 14 Feb 2020 15:57:16 +0100 Subject: [PATCH] Added action EOS --- reading_machine/include/Action.hpp | 2 + reading_machine/include/SubConfig.hpp | 2 +- reading_machine/include/Transition.hpp | 1 + reading_machine/src/Action.cpp | 104 +++++++++++++++++++++++++ reading_machine/src/Config.cpp | 2 +- reading_machine/src/Transition.cpp | 29 +++++++ 6 files changed, 138 insertions(+), 2 deletions(-) diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp index 5fcffe2..03158ad 100644 --- a/reading_machine/include/Action.hpp +++ b/reading_machine/include/Action.hpp @@ -52,6 +52,8 @@ class Action static Action addHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & hypothesis); static Action pushWordIndexOnStack(); static Action popStack(); + static Action emptyStack(); + static Action setRoot(); static Action attach(Object governorObject, int governorIndex, Object dependentObject, int dependentIndex); }; diff --git a/reading_machine/include/SubConfig.hpp b/reading_machine/include/SubConfig.hpp index 70f0ce5..66fbcc0 100644 --- a/reading_machine/include/SubConfig.hpp +++ b/reading_machine/include/SubConfig.hpp @@ -8,7 +8,7 @@ class SubConfig : public Config { private : - static constexpr std::size_t spanSize = 50; + static constexpr std::size_t spanSize = 200; private : diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp index 019a077..c4c6e27 100644 --- a/reading_machine/include/Transition.hpp +++ b/reading_machine/include/Transition.hpp @@ -20,6 +20,7 @@ class Transition void initLeft(std::string label); void initRight(std::string label); void initReduce(); + void initEOS(); public : diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index 1cca6d4..069c3c0 100644 --- a/reading_machine/src/Action.cpp +++ 
b/reading_machine/src/Action.cpp
@@ -169,6 +169,130 @@ Action Action::popStack()
   return {Type::Pop, apply, undo, appliable};
 }
 
+// Builds the action that pops every element of the stack.
+// apply saves the popped values in a.data (top of the stack first),
+// undo pushes them back starting from the last one saved.
+Action Action::emptyStack()
+{
+  auto apply = [](Config & config, Action & a)
+  {
+    while (config.hasStack(0))
+    {
+      a.data.push_back(std::to_string(config.getStack(0)));
+      config.popStack();
+    }
+  };
+
+  auto undo = [](Config & config, Action & a)
+  {
+    while (a.data.size())
+    {
+      config.addToStack(std::stoi(a.data.back()));
+      a.data.pop_back();
+    }
+  };
+
+  // Always appliable : apply does nothing when the stack is already empty.
+  auto appliable = [](const Config & config, const Action &)
+  {
+    return true;
+  };
+
+  return {Type::Pop, apply, undo, appliable};
+}
+
+// Builds the action that closes the dependency tree of the current sentence.
+// apply walks back from the line before the word index, stopping at the
+// first token marked Config::EOSSymbol1 or when the index becomes negative. The
+// leftmost token whose head is empty becomes the root (head "0", deprel
+// "root"), every other token whose head is empty gets the id of the root as
+// head. The index of each of these tokens is saved in a.data, undo uses it to
+// empty the heads again. Throws (util::myThrow) if a line is no longer held.
+// NOTE(review): undo does not empty the "root" deprel written by apply.
+Action Action::setRoot()
+{
+  auto apply = [](Config & config, Action & a)
+  {
+    int rootIndex = -1;
+
+    // First pass : save every token whose head is empty. The scan goes
+    // backward, so rootIndex ends up being the leftmost of them.
+    for (int i = config.getWordIndex()-1; true; --i)
+    {
+      if (!config.has(0, i, 0))
+      {
+        if (i < 0)
+          break;
+        util::myThrow("The current sentence is too long to be completely held by the data structure. Consider increasing SubConfig::spanSize");
+      }
+      if (!config.isToken(i))
+        continue;
+
+      if (config.getLastNotEmptyHypConst(Config::EOSColName, i) == Config::EOSSymbol1)
+        break;
+
+      if (util::isEmpty(config.getLastNotEmptyHypConst(Config::headColName, i)))
+      {
+        rootIndex = i;
+        a.data.push_back(std::to_string(i));
+      }
+    }
+
+    // No token with an empty head : there is nothing to attach, and the id
+    // lookup below must not be made with the line index -1.
+    if (rootIndex < 0)
+      return;
+
+    auto & rootId = config.getLastNotEmptyConst(Config::idColName, rootIndex);
+
+    // Second pass : same scan, this time writing the head of each such token.
+    for (int i = config.getWordIndex()-1; true; --i)
+    {
+      if (!config.has(0, i, 0))
+      {
+        if (i < 0)
+          break;
+        util::myThrow("The current sentence is too long to be completely held by the data structure. Consider increasing SubConfig::spanSize");
+      }
+      if (!config.isToken(i))
+        continue;
+
+      if (config.getLastNotEmptyHypConst(Config::EOSColName, i) == Config::EOSSymbol1)
+        break;
+
+      if (util::isEmpty(config.getLastNotEmptyHyp(Config::headColName, i)))
+      {
+        if (i == rootIndex)
+        {
+          config.getFirstEmpty(Config::headColName, i) = "0";
+          config.getFirstEmpty(Config::deprelColName, i) = "root";
+        }
+        else
+        {
+          config.getFirstEmpty(Config::headColName, i) = rootId;
+        }
+      }
+    }
+
+  };
+
+  auto undo = [](Config & config, Action & a)
+  {
+    while (a.data.size())
+    {
+      config.getLastNotEmptyHyp(Config::headColName, std::stoi(a.data.back())) = "";
+      a.data.pop_back();
+    }
+  };
+
+  auto appliable = [](const Config & config, const Action &)
+  {
+    return config.hasStack(0);
+  };
+
+  return {Type::Write, apply, undo, appliable};
+}
+
 Action Action::attach(Object governorObject, int governorIndex, Object dependentObject, int dependentIndex)
 {
   auto apply = [governorObject, governorIndex, dependentObject, dependentIndex](Config & config, Action & a)
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 0e5f6d0..8e6a12f 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -197,7 +197,7 @@ Config::String & Config::getFirstEmpty(int colIndex, int lineIndex)
 {
   int baseIndex = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex);
 
-  for (int i = 0; i <= nbHypothesesMax; ++i)
+  for (int i = 1; i < nbHypothesesMax; ++i)
     if (util::isEmpty(lines[baseIndex+i]))
       return lines[baseIndex+i];
 
diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp
index bf35c0a..f792b9f 100644
--- a/reading_machine/src/Transition.cpp
+++ b/reading_machine/src/Transition.cpp
@@ -10,6 +10,7 @@ Transition::Transition(const std::string & name)
   std::regex reduceRegex("REDUCE");
   std::regex leftRegex("LEFT (.+)");
   std::regex rightRegex("RIGHT (.+)");
+  std::regex eosRegex("EOS");
 
   try
   {
@@ -24,6 +25,8 @@
Transition::Transition(const std::string & name)
       return;
     if (util::doIfNameMatch(rightRegex, name, [this](auto sm){initRight(sm[1]);}))
       return;
+    if (util::doIfNameMatch(eosRegex, name, [this](auto){initEOS();}))
+      return;
 
     throw std::invalid_argument("no match");
 
@@ -248,3 +251,42 @@ void Transition::initReduce()
   };
 }
 
+// Initializes the EOS transition.
+// Its sequence sets the root of the sentence, writes Config::EOSSymbol1 in the
+// EOS column of the top of the stack, then empties the stack.
+void Transition::initEOS()
+{
+  sequence.emplace_back(Action::setRoot());
+  sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Action::Object::Stack, 0, Config::EOSSymbol1));
+  sequence.emplace_back(Action::emptyStack());
+
+  // Cost of the transition : maximal when the top of the stack cannot be
+  // used, otherwise one point for each of the two conditions checked below.
+  cost = [](const Config & config)
+  {
+    // Everything below reads the top of the stack, so there must be one.
+    if (!config.hasStack(0))
+      return std::numeric_limits<int>::max();
+
+    const auto stackTop = config.getStack(0);
+
+    if (!config.has(0, stackTop, 0))
+      return std::numeric_limits<int>::max();
+
+    if (!config.isToken(stackTop))
+      return std::numeric_limits<int>::max();
+
+    int nbErrors = 0;
+
+    // The EOS column (index 0) of the top of the stack is not EOSSymbol1.
+    if (config.getConst(Config::EOSColName, stackTop, 0) != Config::EOSSymbol1)
+      ++nbErrors;
+
+    // The head of the top of the stack is still empty.
+    if (util::isEmpty(config.getLastNotEmptyHypConst(Config::headColName, stackTop)))
+      ++nbErrors;
+
+    return nbErrors;
+  };
+}
+
-- 
GitLab