#include "Transition.hpp"
#include <limits>
#include <regex>
#include <stdexcept>
#include <string>

/// @brief Builds a transition from its textual name.
///
/// Recognised names, tried in this order:
///   - "WRITE <b|s>.<index> <column> <value>"
///   - "SHIFT"
///   - "REDUCE"
///   - "LEFT <label>"
///   - "RIGHT <label>"
///   - "EOS"
///
/// The matching init* method fills `sequence` (the actions to apply) and
/// `cost` (the oracle cost function). Two failures are reported through
/// util::myThrow with a message containing the offending name:
///   - no pattern matches, or
///   - an init* method throws (e.g. std::stoi on a malformed WRITE index).
Transition::Transition(const std::string & name)
{
  this->name = name;

  // The patterns are constant, so compile them once rather than on every
  // construction (std::regex compilation is comparatively expensive).
  // NOTE(review): every WRITE capture group is a greedy `.+`. In
  // "WRITE b.0 COL a b" the index group therefore absorbs "0 COL", and
  // column/value become "a"/"b". Confirm values never contain spaces.
  static const std::regex writeRegex("WRITE ([bs])\\.(.+) (.+) (.+)");
  static const std::regex shiftRegex("SHIFT");
  static const std::regex reduceRegex("REDUCE");
  static const std::regex leftRegex("LEFT (.+)");
  static const std::regex rightRegex("RIGHT (.+)");
  static const std::regex eosRegex("EOS");

  try
  {
    // doIfNameMatch presumably runs the callback and returns true when the
    // pattern matches; the first successful match ends construction.
    if (util::doIfNameMatch(writeRegex, name, [this](auto sm){initWrite(sm[3], sm[1], sm[2], sm[4]);}))
      return;
    if (util::doIfNameMatch(shiftRegex, name, [this](auto){initShift();}))
      return;
    if (util::doIfNameMatch(reduceRegex, name, [this](auto){initReduce();}))
      return;
    if (util::doIfNameMatch(leftRegex, name, [this](auto sm){initLeft(sm[1]);}))
      return;
    if (util::doIfNameMatch(rightRegex, name, [this](auto sm){initRight(sm[1]);}))
      return;
    if (util::doIfNameMatch(eosRegex, name, [this](auto){initEOS();}))
      return;

    throw std::invalid_argument("no match");
  }
  catch (const std::exception & e)
  {
    util::myThrow(fmt::format("Invalid name '{}' ({})", name, e.what()));
  }
}

/// Applies every action of the transition to `config`, in sequence order.
void Transition::apply(Config & config)
{
  for (Action & action : sequence)
    action.apply(config, action);
}

/// Returns true only if every action of the transition is appliable to `config`.
bool Transition::appliable(const Config & config) const
{
  for (const Action & action : sequence)
    if (!action.appliable(config, action))
      return false;

  return true;
}

/// Returns the oracle cost of applying this transition in `config`.
/// std::numeric_limits<int>::max() is used by the cost functions below to
/// mark the transition as forbidden.
int Transition::getCost(const Config & config) const
{
  return cost(config);
}

/// Returns the name this transition was built from.
const std::string & Transition::getName() const
{
  return name;
}

/// WRITE: writes `value` into column `colName` of one line.
/// The line is selected by `object` (parsed by Action::str2object;
/// presumably "b" = buffer, "s" = stack) and the relative `index`.
/// Cost: 0 if getConst already holds `value` on that line, 1 otherwise.
/// @throws std::invalid_argument / std::out_of_range from std::stoi if
///         `index` is not an integer.
void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
{
  auto objectValue = Action::str2object(object);
  int indexValue = std::stoi(index);

  sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));

  cost = [colName, objectValue, indexValue, value](const Config & config)
  {
    int lineIndex = 0;
    if (objectValue == Action::Object::Buffer)
      lineIndex = config.getWordIndex() + indexValue;
    else
      lineIndex = config.getStack(indexValue);

    if (config.getConst(colName, lineIndex, 0) == value)
      return 0;

    return 1;
  };
}

/// SHIFT: pushes the current word index onto the stack.
/// Cost:
///   - forbidden if the stack top is marked with Config::EOSSymbol1;
///   - 0 if the current word is not a token;
///   - otherwise, the number of stack elements whose head (Config::headColName
///     via getConst) is the current word, or which are the head of the
///     current word.
void Transition::initShift()
{
  sequence.emplace_back(Action::pushWordIndexOnStack());

  cost = [](const Config & config)
  {
    if (config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
      return std::numeric_limits<int>::max();

    if (!config.isToken(config.getWordIndex()))
      return 0;

    auto headGovIndex = config.getConst(Config::headColName, config.getWordIndex(), 0);
    // Loop-invariant: computed once instead of once per stack element.
    const std::string wordIndexStr = std::to_string(config.getWordIndex());

    int total = 0;
    for (int i = 0; config.hasStack(i); ++i)
    {
      if (!config.has(0, config.getStack(i), 0))
        continue;

      auto stackIndex = config.getStack(i);
      auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
      if (stackGovIndex == wordIndexStr || headGovIndex == std::to_string(stackIndex))
        ++total;
    }

    return total;
  };
}

/// LEFT <label>: three actions, in order:
///   1. attach(Buffer 0, Stack 0) — presumably the buffer word becomes the
///      governor of the stack top;
///   2. write `label` into the deprel column of the stack top;
///   3. pop the stack.
/// Cost: forbidden unless both the stack top and the current word are tokens.
/// Otherwise it is the number of tokens after the current word (up to the
/// end-of-sentence mark) that are the stack top's head or have it as head,
/// plus 1 if the stack top's head is not the current word, plus 1 if its
/// deprel differs from `label`.
void Transition::initLeft(std::string label)
{
  sequence.emplace_back(Action::attach(Action::Object::Buffer, 0, Action::Object::Stack, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Stack, 0, label));
  sequence.emplace_back(Action::popStack());

  cost = [label](const Config & config)
  {
    auto stackIndex = config.getStack(0);
    auto wordIndex = config.getWordIndex();
    if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
      return std::numeric_limits<int>::max();

    int total = 0;
    auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
    const std::string stackIndexStr = std::to_string(stackIndex);

    // NOTE(review): unlike initRight/initReduce, this loop breaks on the
    // end-of-sentence line *before* counting it, so arcs involving that line
    // are not counted here — confirm this asymmetry is intended.
    for (int i = wordIndex+1; config.has(0, i, 0); ++i)
    {
      if (!config.isToken(i))
        continue;

      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
        break;

      auto otherGovIndex = config.getConst(Config::headColName, i, 0);
      if (stackGovIndex == std::to_string(i) || otherGovIndex == stackIndexStr)
        ++total;
    }

    //TODO : Check if this is necessary
    if (stackGovIndex != std::to_string(wordIndex))
      ++total;

    if (label != config.getConst(Config::deprelColName, stackIndex, 0))
      ++total;

    return total;
  };
}

/// RIGHT <label>: three actions, in order:
///   1. attach(Stack 0, Buffer 0) — presumably the stack top becomes the
///      governor of the buffer word;
///   2. write `label` into the deprel column of the buffer word;
///   3. push the word index onto the stack.
/// Cost: forbidden unless both the stack top and the current word are tokens.
/// Otherwise it is the sum of:
///   - the number of tokens from the current word up to and including the
///     end-of-sentence mark that are its head or have it as head;
///   - the number of deeper stack elements (index >= 1) in the same relation
///     with the current word;
///   - 1 if the current word's head is not the stack top;
///   - 1 if its deprel differs from `label`.
void Transition::initRight(std::string label)
{
  sequence.emplace_back(Action::attach(Action::Object::Stack, 0, Action::Object::Buffer, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Buffer, 0, label));
  sequence.emplace_back(Action::pushWordIndexOnStack());

  cost = [label](const Config & config)
  {
    auto stackIndex = config.getStack(0);
    auto wordIndex = config.getWordIndex();
    if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
      return std::numeric_limits<int>::max();

    int total = 0;
    auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);
    // Loop-invariant: used by both loops below.
    const std::string wordIndexStr = std::to_string(wordIndex);

    for (int i = wordIndex; config.has(0, i, 0); ++i)
    {
      if (!config.isToken(i))
        continue;

      auto otherGovIndex = config.getConst(Config::headColName, i, 0);
      if (bufferGovIndex == std::to_string(i) || otherGovIndex == wordIndexStr)
        ++total;

      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
        break;
    }

    for (int i = 1; config.hasStack(i); ++i)
    {
      if (!config.has(0, config.getStack(i), 0))
        continue;

      auto otherStackIndex = config.getStack(i);
      auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0);
      if (otherStackGov == wordIndexStr || bufferGovIndex == std::to_string(otherStackIndex))
        ++total;
    }

    //TODO : Check if this is necessary
    if (bufferGovIndex != std::to_string(stackIndex))
      ++total;

    if (label != config.getConst(Config::deprelColName, wordIndex, 0))
      ++total;

    return total;
  };
}

/// REDUCE: pops the stack.
/// Cost:
///   - 0 if the stack top line does not exist or is not a token;
///   - otherwise, the number of tokens from the current word up to and
///     including the end-of-sentence mark that are the stack top's head or
///     have it as head, plus 1 if the stack top itself is marked with
///     Config::EOSSymbol1.
void Transition::initReduce()
{
  sequence.emplace_back(Action::popStack());

  cost = [](const Config & config)
  {
    if (!config.has(0, config.getStack(0), 0))
      return 0;

    if (!config.isToken(config.getStack(0)))
      return 0;

    int total = 0;
    auto stackIndex = config.getStack(0);
    auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
    const std::string stackIndexStr = std::to_string(stackIndex);

    for (int i = config.getWordIndex(); config.has(0, i, 0); ++i)
    {
      if (!config.isToken(i))
        continue;

      auto otherGovIndex = config.getConst(Config::headColName, i, 0);
      if (stackGovIndex == std::to_string(i) || otherGovIndex == stackIndexStr)
        ++total;

      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
        break;
    }

    if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
      ++total;

    return total;
  };
}

/// EOS: four actions, in order:
///   1. setRoot;
///   2. updateIds;
///   3. write Config::EOSSymbol1 into the EOS column of the stack top;
///   4. empty the stack.
/// Cost: forbidden unless the stack top exists, is a token, and is marked
/// with Config::EOSSymbol1 (via getConst). Otherwise it is the number of stack
/// elements whose last non-empty predicted head (getLastNotEmptyHypConst) is
/// empty, minus one.
void Transition::initEOS()
{
  sequence.emplace_back(Action::setRoot());
  sequence.emplace_back(Action::updateIds());
  sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Action::Object::Stack, 0, Config::EOSSymbol1));
  sequence.emplace_back(Action::emptyStack());

  cost = [](const Config & config)
  {
    if (!config.has(0, config.getStack(0), 0))
      return std::numeric_limits<int>::max();

    if (!config.isToken(config.getStack(0)))
      return std::numeric_limits<int>::max();

    if (config.getConst(Config::EOSColName, config.getStack(0), 0) != Config::EOSSymbol1)
      return std::numeric_limits<int>::max();

    // Starts at -1, presumably so the single headless element that setRoot
    // turns into the root is not counted — confirm.
    int total = -1;
    for (int i = 0; config.hasStack(i); ++i)
    {
      if (!config.has(0, config.getStack(i), 0))
        continue;

      auto otherStackIndex = config.getStack(i);
      auto otherStackGovPred = config.getLastNotEmptyHypConst(Config::headColName, otherStackIndex);
      if (util::isEmpty(otherStackGovPred))
        ++total;
    }

    return total;
  };
}