diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp index 5fcffe2aff0c0b9359594a11a6c1eb0768ff9bdf..03158ad28bc856efecdec769b8de9840ed1c6ac6 100644 --- a/reading_machine/include/Action.hpp +++ b/reading_machine/include/Action.hpp @@ -52,6 +52,8 @@ class Action static Action addHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & hypothesis); static Action pushWordIndexOnStack(); static Action popStack(); + static Action emptyStack(); + static Action setRoot(); static Action attach(Object governorObject, int governorIndex, Object dependentObject, int dependentIndex); }; diff --git a/reading_machine/include/SubConfig.hpp b/reading_machine/include/SubConfig.hpp index 70f0ce589a4826259d8364e944f10de034222200..66fbcc0ae7d0d4babf808517f329fdf59684ec18 100644 --- a/reading_machine/include/SubConfig.hpp +++ b/reading_machine/include/SubConfig.hpp @@ -8,7 +8,7 @@ class SubConfig : public Config { private : - static constexpr std::size_t spanSize = 50; + static constexpr std::size_t spanSize = 200; private : diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp index 019a077b98f8b9df2c44e302f210ed39c792657d..c4c6e272195788e3f00ba37b2aebced5f8cb813f 100644 --- a/reading_machine/include/Transition.hpp +++ b/reading_machine/include/Transition.hpp @@ -20,6 +20,7 @@ class Transition void initLeft(std::string label); void initRight(std::string label); void initReduce(); + void initEOS(); public : diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index 1cca6d48788229f13685ece03536ad9dc4e1a901..069c3c0207e21c1a008f0de6b1c9a6e313eb63f7 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -169,6 +169,110 @@ Action Action::popStack() return {Type::Pop, apply, undo, appliable}; } +Action Action::emptyStack() +{ + auto apply = [](Config & config, Action & a) + { + while (config.hasStack(0)) + { + a.data.push_back(std::to_string(config.getStack(0))); + config.popStack(); + } + }; + + auto undo = [](Config & config, Action & a) + { + while (a.data.size()) + { + config.addToStack(std::stoi(a.data.back())); + a.data.pop_back(); + } + }; + + auto appliable = [](const Config & config, const Action &) + { + return true; + }; + + return {Type::Pop, apply, undo, appliable}; +} + +Action Action::setRoot() +{ + auto apply = [](Config & config, Action & a) + { + int rootIndex = -1; + + for (int i = config.getWordIndex()-1; true; --i) + { + if (!config.has(0, i, 0)) + { + if (i < 0) + break; + util::myThrow("The current sentence is too long to be completly held by the data strucure. Consider increasing SubConfig::SpanSize"); + } + if (!config.isToken(i)) + continue; + + if (config.getLastNotEmptyHypConst(Config::EOSColName, i) == Config::EOSSymbol1) + break; + + if (util::isEmpty(config.getLastNotEmptyHypConst(Config::headColName, i))) + { + rootIndex = i; + a.data.push_back(std::to_string(i)); + } + } + + auto & rootId = config.getLastNotEmptyConst(Config::idColName, rootIndex); + + for (int i = config.getWordIndex()-1; true; --i) + { + if (!config.has(0, i, 0)) + { + if (i < 0) + break; + util::myThrow("The current sentence is too long to be completly held by the data strucure. Consider increasing SubConfig::SpanSize"); + } + if (!config.isToken(i)) + continue; + + if (config.getLastNotEmptyHypConst(Config::EOSColName, i) == Config::EOSSymbol1) + break; + + if (util::isEmpty(config.getLastNotEmptyHyp(Config::headColName, i))) + { + if (i == rootIndex) + { + config.getFirstEmpty(Config::headColName, i) = "0"; + config.getFirstEmpty(Config::deprelColName, i) = "root"; + } + else + { + config.getFirstEmpty(Config::headColName, i) = rootId; + } + } + } + + }; + + auto undo = [](Config & config, Action & a) + { + while (a.data.size()) + { + config.getLastNotEmptyHyp(Config::headColName, std::stoi(a.data.back())) = ""; + a.data.pop_back(); + } + }; + + auto appliable = [](const Config & config, const Action &) + { + return config.hasStack(0); + }; + + return {Type::Write, apply, undo, appliable}; +} + Action Action::attach(Object governorObject, int governorIndex, Object dependentObject, int dependentIndex) { auto apply = [governorObject, governorIndex, dependentObject, dependentIndex](Config & config, Action & a) diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 0e5f6d0ea5579fd3b8b07b8711a0be9a1d03691e..8e6a12f063c36f0463c2833158613af7e645a57b 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -197,7 +197,7 @@ Config::String & Config::getFirstEmpty(int colIndex, int lineIndex) { int baseIndex = getIndexOfLine(lineIndex-getFirstLineIndex()) + getIndexOfCol(colIndex); - for (int i = 0; i <= nbHypothesesMax; ++i) + for (int i = 1; i < nbHypothesesMax; ++i) if (util::isEmpty(lines[baseIndex+i])) return lines[baseIndex+i]; diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index bf35c0a036159ae822dcdb5c8937b930aaebe67e..f792b9fc6a536b31be4dcb80e83f8a5e26434e99 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -10,6 +10,7 @@ Transition::Transition(const std::string & name) std::regex reduceRegex("REDUCE"); std::regex leftRegex("LEFT (.+)"); std::regex rightRegex("RIGHT (.+)"); + std::regex eosRegex("EOS"); try { @@ -24,6 +25,8 @@ Transition::Transition(const std::string & name) return; if (util::doIfNameMatch(rightRegex, name, [this](auto sm){initRight(sm[1]);})) return; + if (util::doIfNameMatch(eosRegex, name, [this](auto){initEOS();})) + return; throw std::invalid_argument("no match"); @@ -248,3 +251,29 @@ void Transition::initReduce() }; } +void Transition::initEOS() +{ + sequence.emplace_back(Action::setRoot()); + sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Action::Object::Stack, 0, Config::EOSSymbol1)); + sequence.emplace_back(Action::emptyStack()); + + cost = [](const Config & config) + { + if (!config.has(0, config.getStack(0), 0)) + return std::numeric_limits<int>::max(); + + if (!config.isToken(config.getStack(0))) + return std::numeric_limits<int>::max(); + + int cost = 0; + + if (config.getConst(Config::EOSColName, config.getStack(0), 0) != Config::EOSSymbol1) + ++cost; + + if (util::isEmpty(config.getLastNotEmptyHypConst(Config::headColName, config.getStack(0)))) + ++cost; + + return cost; + }; +} +