#include "Transition.hpp"
#include <functional>
#include <limits>
#include <regex>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

/// @brief Build a transition from its textual name.
///
/// The name has the form `[<state> ]BODY`. The optional `<state>` prefix is
/// stored in `state`, the remainder in `name`. BODY is then tried against each
/// pattern of the dispatch table below, in order; the callback of the first
/// matching pattern calls the corresponding init* method, which fills
/// `sequence` and `cost`.
///
/// @param name Textual description of the transition.
///
/// Any std::exception raised while parsing (unknown BODY, std::stoi failure,
/// ...) is caught and re-reported through util::myThrow with the full input
/// name, so the offending text is visible even when the outer name pattern
/// itself failed to match.
Transition::Transition(const std::string & name)
{
  // Dispatch table: pattern -> initializer. The callbacks take the match by
  // const reference, so no std::smatch copy is made per call.
  std::vector<std::pair<std::regex, std::function<void(const std::smatch &)>>> inits
  {
    {std::regex("WRITE ([bs])\\.(.+) (.+) (.+)"),
      [this](const auto & sm){initWrite(sm[3], sm[1], sm[2], sm[4]);}},
    {std::regex("ADD ([bs])\\.(.+) (.+) (.+)"),
      [this](const auto & sm){initAdd(sm[3], sm[1], sm[2], sm[4]);}},
    {std::regex("SHIFT"),
      [this](const auto &){initShift();}},
    {std::regex("REDUCE"),
      [this](const auto &){initReduce();}},
    {std::regex("LEFT (.+)"),
      [this](const auto & sm){initLeft(sm[1]);}},
    {std::regex("RIGHT (.+)"),
      [this](const auto & sm){initRight(sm[1]);}},
    {std::regex("EOS"),
      [this](const auto &){initEOS();}},
    {std::regex("NOTHING"),
      [this](const auto &){initNothing();}},
    {std::regex("IGNORECHAR"),
      [this](const auto &){initIgnoreChar();}},
    {std::regex("ENDWORD"),
      [this](const auto &){initEndWord();}},
    {std::regex("ADDCHARTOWORD"),
      [this](const auto &){initAddCharToWord();}},
    {std::regex("SPLIT (.+)"),
      [this](const auto & sm){initSplit(std::stoi(sm.str(1)));}},
    // NOTE(review): "(:?" below looks like a typo for the non-capturing group
    // "(?:". As written these are capturing groups with an optional ':'.
    // The pattern is left untouched because sm[2] relies on the group numbering.
    {std::regex("SPLITWORD ([^@]+)(:?(:?@[^@]+)+)"),
      [this](const auto & sm)
      {
        // First element is the text before the first '@', followed by the
        // '@'-separated parts of the second capture group.
        std::vector<std::string> splitRes{sm[1]};
        auto splited = util::split(std::string(sm[2]), '@');
        for (auto & s : splited)
          splitRes.emplace_back(s);
        initSplitWord(splitRes);
      }},
  };

  try
  {
    // Separate the optional "<state> " prefix from the transition body.
    if (!util::doIfNameMatch(std::regex("(<(.+)> )?(.+)"), name, [this](const auto & sm)
          {
            this->state = sm[2];
            this->name = sm[3];
          }))
      util::myThrow("doesn't match nameRegex");

    // The first pattern that matches wins.
    for (auto & it : inits)
      if (util::doIfNameMatch(it.first, this->name, it.second))
        return;

    throw std::invalid_argument("no match");
  }
  catch (const std::exception & e)
  {
    // Report the constructor argument: this->name is still empty when the
    // outer pattern failed, which would hide the offending input.
    util::myThrow(fmt::format("Invalid name '{}' ({})", name, e.what()));
  }
}

/// @brief Apply every action of `sequence`, in order, to `config`.
void Transition::apply(Config & config)
{
  for (Action & action : sequence)
    action.apply(config, action);
}

/// @brief Tell whether this transition can be applied to `config`.
/// @return false when a non-empty `state` differs from config.getState(), or
///         when any action of `sequence` reports it is not appliable;
///         true otherwise.
bool Transition::appliable(const Config & config) const
{
  if (!state.empty() && state != config.getState())
    return false;

  for (const Action & action : sequence)
    if (!action.appliable(config, action))
      return false;

  return true;
}

/// @brief Evaluate the cost function installed by the init* method on `config`.
int Transition::getCost(const Config & config) const
{
  return cost(config);
}

/// @brief Name of the transition, without the optional `<state> ` prefix.
const std::string & Transition::getName() const
{
  return name;
}

/// @brief Set up a "WRITE" transition: one Action::addHypothesisRelative action.
/// @param colName Column that receives the value.
/// @param object  Textual object, converted with Action::str2object.
/// @param index   Relative index as text (parsed with std::stoi, which may throw).
/// @param value   Value to write.
///
/// Cost is 0 when config.getConst(colName, line, 0) equals `value`, 1 otherwise.
void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
{
  auto objectValue = Action::str2object(object);
  int indexValue = std::stoi(index);

  sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));

  cost = [colName, objectValue, indexValue, value](const Config & config)
  {
    // Resolve the relative index: offset from the current word for the
    // buffer, stack lookup otherwise.
    int lineIndex = 0;
    if (objectValue == Action::Object::Buffer)
      lineIndex = config.getWordIndex() + indexValue;
    else
      lineIndex = config.getStack(indexValue);

    if (config.getConst(colName, lineIndex, 0) == value)
      return 0;

    return 1;
  };
}

/// @brief Set up an "ADD" transition: one Action::addToHypothesisRelative action.
/// @param colName Column that receives the value.
/// @param object  Textual object, converted with Action::str2object.
/// @param index   Relative index as text (parsed with std::stoi, which may throw).
/// @param value   Value to add.
///
/// Cost is 0 when `value` is one of the '|'-separated parts of
/// config.getConst(colName, line, 0), 1 otherwise.
void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
{
  auto objectValue = Action::str2object(object);
  int indexValue = std::stoi(index);

  sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));

  cost = [colName, objectValue, indexValue, value](const Config & config)
  {
    // Same relative index resolution as in initWrite.
    int lineIndex = 0;
    if (objectValue == Action::Object::Buffer)
      lineIndex = config.getWordIndex() + indexValue;
    else
      lineIndex = config.getStack(indexValue);

    auto gold = util::split(config.getConst(colName, lineIndex, 0).get(), '|');

    for (const auto & part : gold)
      if (part == value)
        return 0;

    return 1;
  };
}

/// @brief Set up a "NOTHING" transition: no action, cost always 0.
void Transition::initNothing()
{
  cost = [](const Config &)
  {
    return 0;
  };
}

/// @brief Set up an "IGNORECHAR" transition: Action::ignoreCurrentCharacter, cost always 0.
void Transition::initIgnoreChar()
{
  sequence.emplace_back(Action::ignoreCurrentCharacter());

  cost = [](const Config &)
  {
    return 0;
  };
}

/// @brief Set up an "ENDWORD" transition: Action::endWord.
///
/// Cost is 0 when the "FORM" returned by getConst for the current word equals
/// the one returned by getAsFeature, 1 otherwise.
void Transition::initEndWord()
{
  sequence.emplace_back(Action::endWord());

  cost = [](const Config & config)
  {
    if (config.getConst("FORM", config.getWordIndex(), 0) == config.getAsFeature("FORM", config.getWordIndex()))
      return 0;

    return 1;
  };
}

/// @brief Set up an "ADDCHARTOWORD" transition.
///
/// Actions, in order: assertIsEmpty(Config::idColName), addLinesIfNeeded(0),
/// addCurCharToCurWord, moveCharacterIndex(1).
///
/// Cost is std::numeric_limits<int>::max() when there is no character at the
/// current character index; 1 when appending the current letter to the word
/// built so far would overrun, or differ from, the getConst "FORM"; 0 otherwise.
void Transition::initAddCharToWord()
{
  sequence.emplace_back(Action::assertIsEmpty(Config::idColName));
  sequence.emplace_back(Action::addLinesIfNeeded(0));
  sequence.emplace_back(Action::addCurCharToCurWord());
  sequence.emplace_back(Action::moveCharacterIndex(1));

  cost = [](const Config & config)
  {
    if (!config.hasCharacter(config.getCharacterIndex()))
      return std::numeric_limits<int>::max();

    // The letter is formatted to a string, so it may span more than one char.
    auto letter = fmt::format("{}", config.getLetter(config.getCharacterIndex()));
    auto & goldWord = config.getConst("FORM", config.getWordIndex(), 0).get();
    auto & curWord = config.getAsFeature("FORM", config.getWordIndex()).get();

    // This size check also keeps the indexing below within goldWord.
    if (curWord.size() + letter.size() > goldWord.size())
      return 1;

    for (unsigned int i = 0; i < letter.size(); i++)
      if (goldWord[curWord.size()+i] != letter[i])
        return 1;

    return 0;
  };
}

/// @brief Set up a "SPLITWORD" transition.
/// @param words words[0] is the text to consume, as passed by the constructor;
///              every element is written as a "FORM" at successive buffer
///              offsets. Must not be empty (words[0] is read unconditionally).
///
/// Cost is std::numeric_limits<int>::max() when the current word is not a
/// multiword, or when its multiword size + 2 differs from words.size().
/// Otherwise it is the number of offsets whose "FORM" is missing or differs
/// from the corresponding element of `words`.
void Transition::initSplitWord(std::vector<std::string> words)
{
  auto consumedWord = util::splitAsUtf8(words[0]);

  sequence.emplace_back(Action::assertIsEmpty(Config::idColName));
  sequence.emplace_back(Action::assertIsEmpty("FORM"));
  sequence.emplace_back(Action::addLinesIfNeeded(words.size()));
  sequence.emplace_back(Action::consumeCharacterIndex(consumedWord));
  for (unsigned int i = 0; i < words.size(); i++)
    sequence.emplace_back(Action::addHypothesisRelative("FORM", Action::Object::Buffer, i, words[i]));
  sequence.emplace_back(Action::setMultiwordIds(words.size()-1));

  cost = [words](const Config & config)
  {
    if (!config.isMultiword(config.getWordIndex()))
      return std::numeric_limits<int>::max();

    if (config.getMultiwordSize(config.getWordIndex())+2 != static_cast<int>(words.size()))
      return std::numeric_limits<int>::max();

    int cost = 0;
    for (unsigned int i = 0; i < words.size(); i++)
      if (!config.has("FORM", config.getWordIndex()+i, 0) or config.getConst("FORM", config.getWordIndex()+i, 0) != words[i])
        cost++;

    return cost;
  };
}

/// @brief Set up a "SPLIT" transition: Action::split(index).
/// @param index Position in config.getAppliableSplitTransitions().
///
/// Cost is std::numeric_limits<int>::max() when `index` is out of range,
/// otherwise the cost of the split transition found at that position.
void Transition::initSplit(int index)
{
  sequence.emplace_back(Action::split(index));

  cost = [index](const Config & config)
  {
    auto & transitions = config.getAppliableSplitTransitions();

    if (index < 0 or index >= static_cast<int>(transitions.size()))
      return std::numeric_limits<int>::max();

    return transitions[index]->getCost(config);
  };
}

/// @brief Set up a "SHIFT" transition: Action::pushWordIndexOnStack.
///
/// Cost is std::numeric_limits<int>::max() when the top of the stack carries
/// Config::EOSSymbol1 in the EOS column, and 0 when the current word is not a
/// token. Otherwise it counts the stack elements whose head (getConst on
/// Config::headColName) is the current word, or which are the head of the
/// current word.
void Transition::initShift()
{
  sequence.emplace_back(Action::pushWordIndexOnStack());

  cost = [](const Config & config)
  {
    if (config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
      return std::numeric_limits<int>::max();

    if (!config.isToken(config.getWordIndex()))
      return 0;

    auto headGovIndex = config.getConst(Config::headColName, config.getWordIndex(), 0);

    int cost = 0;
    for (int i = 0; config.hasStack(i); ++i)
    {
      if (!config.has(0, config.getStack(i), 0))
        continue;

      auto stackIndex = config.getStack(i);
      auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);

      if (stackGovIndex == std::to_string(config.getWordIndex()) || headGovIndex == std::to_string(stackIndex))
        ++cost;
    }

    return cost;
  };
}

/// @brief Set up a "LEFT <label>" transition.
/// @param label Value written to Config::deprelColName for the stack top.
///
/// Actions, in order: attach(Buffer 0, Stack 0), write `label` on the stack
/// top, popStack.
///
/// Cost is std::numeric_limits<int>::max() unless both the stack top and the
/// current word are tokens. Otherwise it counts:
///  - words after the current one (the scan stops before a word marked
///    Config::EOSSymbol1) that are the head of the stack top or have it as head,
///  - +1 when the head of the stack top is not the current word,
///  - +1 when `label` differs from the stack top's deprel.
void Transition::initLeft(std::string label)
{
  sequence.emplace_back(Action::attach(Action::Object::Buffer, 0, Action::Object::Stack, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Stack, 0, label));
  sequence.emplace_back(Action::popStack());

  cost = [label](const Config & config)
  {
    auto stackIndex = config.getStack(0);
    auto wordIndex = config.getWordIndex();

    if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
      return std::numeric_limits<int>::max();

    int cost = 0;

    auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);

    for (int i = wordIndex+1; config.has(0, i, 0); ++i)
    {
      if (!config.isToken(i))
        continue;

      // The EOS-marked word itself is not examined here.
      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
        break;

      auto otherGovIndex = config.getConst(Config::headColName, i, 0);

      if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
        ++cost;
    }

    //TODO : Check if this is necessary
    if (stackGovIndex != std::to_string(wordIndex))
      ++cost;

    if (label != config.getConst(Config::deprelColName, stackIndex, 0))
      ++cost;

    return cost;
  };
}

/// @brief Set up a "RIGHT <label>" transition.
/// @param label Value written to Config::deprelColName for the current word.
///
/// Actions, in order: attach(Stack 0, Buffer 0), write `label` on the current
/// word, pushWordIndexOnStack.
///
/// Cost is std::numeric_limits<int>::max() unless both the stack top and the
/// current word are tokens. Otherwise it counts:
///  - words from the current one onward (up to and including a word marked
///    Config::EOSSymbol1) that are the head of the current word,
///  - stack elements below the top that are the head of the current word or
///    have it as head,
///  - +1 when the head of the current word is not the stack top,
///  - +1 when `label` differs from the current word's deprel.
void Transition::initRight(std::string label)
{
  sequence.emplace_back(Action::attach(Action::Object::Stack, 0, Action::Object::Buffer, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Buffer, 0, label));
  sequence.emplace_back(Action::pushWordIndexOnStack());

  cost = [label](const Config & config)
  {
    auto stackIndex = config.getStack(0);
    auto wordIndex = config.getWordIndex();

    if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
      return std::numeric_limits<int>::max();

    int cost = 0;

    auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);

    for (int i = wordIndex; config.has(0, i, 0); ++i)
    {
      if (!config.isToken(i))
        continue;

      // Only the head of the current word matters in this scan.
      if (bufferGovIndex == std::to_string(i))
        ++cost;

      // Unlike initLeft, the EOS-marked word is examined before stopping.
      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
        break;
    }

    // Starts at 1: the stack top is the element being attached.
    for (int i = 1; config.hasStack(i); ++i)
    {
      if (!config.has(0, config.getStack(i), 0))
        continue;

      auto otherStackIndex = config.getStack(i);
      auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0);

      if (otherStackGov == std::to_string(wordIndex) || bufferGovIndex == std::to_string(otherStackIndex))
        ++cost;
    }

    //TODO : Check if this is necessary
    if (bufferGovIndex != std::to_string(stackIndex))
      ++cost;

    if (label != config.getConst(Config::deprelColName, wordIndex, 0))
      ++cost;

    return cost;
  };
}

/// @brief Set up a "REDUCE" transition: Action::popStack.
///
/// Cost is 0 when the stack top is not a token. Otherwise it counts the words
/// from the current one onward (up to and including a word marked
/// Config::EOSSymbol1) that are the head of the stack top or have it as head,
/// +1 when the stack top itself is marked Config::EOSSymbol1.
void Transition::initReduce()
{
  sequence.emplace_back(Action::popStack());

  cost = [](const Config & config)
  {
    if (!config.isToken(config.getStack(0)))
      return 0;

    int cost = 0;

    auto stackIndex = config.getStack(0);
    auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);

    for (int i = config.getWordIndex(); config.has(0, i, 0); ++i)
    {
      if (!config.isToken(i))
        continue;

      auto otherGovIndex = config.getConst(Config::headColName, i, 0);

      if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
        ++cost;

      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
        break;
    }

    if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
      ++cost;

    return cost;
  };
}

/// @brief Set up an "EOS" transition.
///
/// Actions, in order: setRoot, updateIds, write Config::EOSSymbol1 in the EOS
/// column of the stack top, emptyStack.
///
/// Cost is std::numeric_limits<int>::max() when the stack top does not exist
/// in the config, is not a token, or is not marked Config::EOSSymbol1 by
/// getConst. Otherwise it is the number of stack elements whose head, as
/// returned by getAsFeature, is empty, minus one.
void Transition::initEOS()
{
  sequence.emplace_back(Action::setRoot());
  sequence.emplace_back(Action::updateIds());
  sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Action::Object::Stack, 0, Config::EOSSymbol1));
  sequence.emplace_back(Action::emptyStack());

  cost = [](const Config & config)
  {
    if (!config.has(0, config.getStack(0), 0))
      return std::numeric_limits<int>::max();

    if (!config.isToken(config.getStack(0)))
      return std::numeric_limits<int>::max();

    if (config.getConst(Config::EOSColName, config.getStack(0), 0) != Config::EOSSymbol1)
      return std::numeric_limits<int>::max();

    // Starts at -1 so that exactly one head-less stack element costs 0
    // (presumably the element that setRoot turns into the root - confirm).
    int cost = -1;

    for (int i = 0; config.hasStack(i); ++i)
    {
      if (!config.has(0, config.getStack(i), 0))
        continue;

      auto otherStackIndex = config.getStack(i);
      auto otherStackGovPred = config.getAsFeature(Config::headColName, otherStackIndex);

      if (util::isEmpty(otherStackGovPred))
        ++cost;
    }

    return cost;
  };
}