#include "Transition.hpp"
#include <functional>
#include <limits>
#include <regex>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

/// @brief Build a transition from its textual name.
///
/// The name is first matched against "(<(.+)> )?(.+)": the optional "<...> "
/// prefix is stored in `state` and the remainder in `name`. The remainder is
/// then tried against each known pattern, in order; the first one that matches
/// runs the corresponding init* function, which fills `sequence` and `cost`.
///
/// @param name Textual description of the transition.
/// @note Any std::exception raised while parsing (including the "no match"
///       std::invalid_argument) is reported through util::myThrow.
Transition::Transition(const std::string & name)
{
  // Pattern -> initializer table. The match object is taken by const reference
  // so that dispatching does not copy it.
  std::vector<std::pair<std::regex, std::function<void(const std::smatch &)>>> inits
  {
    {std::regex("WRITE ([bs])\\.(.+) (.+) (.+)"),
      [this](const auto & sm){initWrite(sm[3], sm[1], sm[2], sm[4]);}},
    {std::regex("ADD ([bs])\\.(.+) (.+) (.+)"),
      [this](const auto & sm){initAdd(sm[3], sm[1], sm[2], sm[4]);}},
    {std::regex("eager_SHIFT"),
      [this](const auto &){initEagerShift();}},
    {std::regex("standard_SHIFT"),
      [this](const auto &){initStandardShift();}},
    {std::regex("REDUCE_strict"),
      [this](const auto &){initReduce_strict();}},
    {std::regex("REDUCE_relaxed"),
      [this](const auto &){initReduce_relaxed();}},
    {std::regex("eager_LEFT_rel (.+)"),
      [this](const auto & sm){initEagerLeft_rel(sm[1]);}},
    {std::regex("eager_RIGHT_rel (.+)"),
      [this](const auto & sm){initEagerRight_rel(sm[1]);}},
    {std::regex("standard_LEFT_rel (.+)"),
      [this](const auto & sm){initStandardLeft_rel(sm[1]);}},
    {std::regex("standard_RIGHT_rel (.+)"),
      [this](const auto & sm){initStandardRight_rel(sm[1]);}},
    {std::regex("eager_LEFT"),
      [this](const auto &){initEagerLeft();}},
    {std::regex("eager_RIGHT"),
      [this](const auto &){initEagerRight();}},
    {std::regex("deprel (.+)"),
      [this](const auto & sm){initDeprel(sm[1]);}},
    {std::regex("EOS b\\.(.+)"),
      [this](const auto & sm){initEOS(std::stoi(sm[1]));}},
    {std::regex("NOTHING"),
      [this](const auto &){initNothing();}},
    {std::regex("IGNORECHAR"),
      [this](const auto &){initIgnoreChar();}},
    {std::regex("ENDWORD"),
      [this](const auto &){initEndWord();}},
    {std::regex("ADDCHARTOWORD"),
      [this](const auto &){initAddCharToWord();}},
    {std::regex("SPLIT (.+)"),
      [this](const auto & sm){initSplit(std::stoi(sm.str(1)));}},
    // Group 1 is the consumed word, group 2 the whole "@w1@w2..." tail. The inner
    // group is non-capturing; group 2 must stay capturing because it is read below.
    {std::regex("SPLITWORD ([^@]+)((?:@[^@]+)+)"),
      [this](const auto & sm)
      {
        std::vector<std::string> splitRes{sm[1]};
        auto splited = util::split(std::string(sm[2]), '@');
        for (auto & s : splited)
          splitRes.emplace_back(s);
        initSplitWord(splitRes);
      }},
  };

  try
  {
    // Separate the optional "<state> " prefix from the transition body.
    if (!util::doIfNameMatch(std::regex("(<(.+)> )?(.+)"), name, [this](const auto & sm)
          {
            this->state = sm[2];
            this->name = sm[3];
          }))
      util::myThrow("doesn't match nameRegex");

    for (auto & it : inits)
      if (util::doIfNameMatch(it.first, this->name, it.second))
        return;

    throw std::invalid_argument("no match");
  }
  catch (const std::exception & e)
  {
    // this->name is still empty when the prefix regex itself failed; fall back to
    // the raw argument so the message always shows what was being parsed.
    util::myThrow(fmt::format("Invalid name '{}' ({})", this->name.empty() ? name : this->name, e.what()));
  }
}

/// @brief Apply every action of the sequence, in order, to the configuration.
/// @param config Configuration to modify.
void Transition::apply(Config & config)
{
  for (Action & action : sequence)
    action.apply(config, action);
}

/// @brief Tell whether this transition can be applied to the configuration.
/// @param config Configuration to test.
/// @return false if `state` is non-empty and differs from config.getState(), or if
///         any action of the sequence reports it is not appliable; true otherwise.
/// @note A std::exception raised during the check is reported through
///       util::myThrow, prefixed with the transition name.
bool Transition::appliable(const Config & config) const
{
  try
  {
    if (!state.empty() && state != config.getState())
      return false;

    for (const Action & action : sequence)
      if (!action.appliable(config, action))
        return false;
  }
  catch (const std::exception & e)
  {
    util::myThrow(fmt::format("transition '{}' {}", name, e.what()));
  }

  return true;
}

/// @brief Evaluate the cost function of this transition on the configuration.
/// @param config Configuration to evaluate.
/// @return The value returned by `cost`.
/// @note A std::exception raised by `cost` is reported through util::myThrow,
///       prefixed with the transition name.
int Transition::getCost(const Config & config) const
{
  try
  {
    return cost(config);
  }
  catch (const std::exception & e)
  {
    util::myThrow(fmt::format("transition '{}' {}", name, e.what()));
  }

  return 0;
}

/// @brief Name of the transition, without the optional "<state> " prefix.
const std::string & Transition::getName() const
{
  return name;
}

/// @brief Set up "WRITE": write `value` in column `colName` of a relative word.
/// @param colName Column to write into.
/// @param object  Textual object designation, converted with Config::str2object.
/// @param index   Relative index (decimal string, parsed with std::stoi).
/// @param value   Value to write.
///
/// Cost is 0 when config.getConst(colName, line, 0) equals `value`, 1 otherwise.
void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
{
  auto objectValue = Config::str2object(object);
  int indexValue = std::stoi(index);

  sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));

  cost = [colName, objectValue, indexValue, value](const Config & config)
  {
    int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);

    if (config.getConst(colName, lineIndex, 0) == value)
      return 0;

    return 1;
  };
}

/// @brief Set up "ADD": add `value` to column `colName` of a relative word.
/// @param colName Column to add to.
/// @param object  Textual object designation, converted with Config::str2object.
/// @param index   Relative index (decimal string, parsed with std::stoi).
/// @param value   Value to add.
///
/// Cost is 0 when `value` is one of the '|'-separated parts of
/// config.getConst(colName, line, 0), 1 otherwise.
void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
{
  auto objectValue = Config::str2object(object);
  int indexValue = std::stoi(index);

  sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));

  cost = [colName, objectValue, indexValue, value](const Config & config)
  {
    int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
    auto gold = util::split(config.getConst(colName, lineIndex, 0).get(), '|');

    for (auto & part : gold)
      if (part == value)
        return 0;

    return 1;
  };
}

/// @brief Set up "NOTHING": no action, cost always 0.
void Transition::initNothing()
{
  cost = [](const Config &)
  {
    return 0;
  };
}

/// @brief Set up "IGNORECHAR": ignore the current character, cost always 0.
void Transition::initIgnoreChar()
{
  sequence.emplace_back(Action::ignoreCurrentCharacter());

  cost = [](const Config &)
  {
    return 0;
  };
}

/// @brief Set up "ENDWORD".
///
/// Cost is 0 when the FORM returned by getConst for the current word equals the
/// one returned by getAsFeature, 1 otherwise.
void Transition::initEndWord()
{
  sequence.emplace_back(Action::endWord());

  cost = [](const Config & config)
  {
    if (config.getConst("FORM", config.getWordIndex(), 0) == config.getAsFeature("FORM", config.getWordIndex()))
      return 0;

    return 1;
  };
}

/// @brief Set up "ADDCHARTOWORD": append the current character to the current word.
///
/// Cost is the maximal int when there is no character at the current character
/// index. Otherwise it is 0 when appending the character keeps the current FORM
/// (getAsFeature) a prefix of the FORM returned by getConst, and 1 otherwise.
void Transition::initAddCharToWord()
{
  sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::addLinesIfNeeded(0));
  sequence.emplace_back(Action::addCurCharToCurWord());
  sequence.emplace_back(Action::moveCharacterIndex(1));

  cost = [](const Config & config)
  {
    if (!config.hasCharacter(config.getCharacterIndex()))
      return std::numeric_limits<int>::max();

    auto letter = fmt::format("{}", config.getLetter(config.getCharacterIndex()));
    auto & goldWord = config.getConst("FORM", config.getWordIndex(), 0).get();
    auto & curWord = config.getAsFeature("FORM", config.getWordIndex()).get();

    // The length check also guards the indexing into goldWord below.
    if (curWord.size() + letter.size() > goldWord.size())
      return 1;

    for (unsigned int i = 0; i < letter.size(); i++)
      if (goldWord[curWord.size()+i] != letter[i])
        return 1;

    return 0;
  };
}

/// @brief Set up "SPLITWORD": consume words[0] and write every element of `words`
///        as the FORM of successive buffer lines.
/// @param words words[0] is the consumed word, followed by the '@'-separated
///              parts taken from the transition name.
///
/// Cost is the maximal int when the current word is not a multiword or when its
/// multiword size + 2 differs from words.size(). Otherwise it is the number of
/// positions whose FORM is missing or differs from the corresponding word.
void Transition::initSplitWord(std::vector<std::string> words)
{
  auto consumedWord = util::splitAsUtf8(words[0]);

  sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::assertIsEmpty("FORM", Config::Object::Buffer, 0));
  sequence.emplace_back(Action::addLinesIfNeeded(words.size()));
  sequence.emplace_back(Action::consumeCharacterIndex(consumedWord));
  for (unsigned int i = 0; i < words.size(); i++)
    sequence.emplace_back(Action::addHypothesisRelative("FORM", Config::Object::Buffer, i, words[i]));
  sequence.emplace_back(Action::setMultiwordIds(words.size()-1));

  cost = [words](const Config & config)
  {
    if (!config.isMultiword(config.getWordIndex()))
      return std::numeric_limits<int>::max();

    if (config.getMultiwordSize(config.getWordIndex())+2 != static_cast<int>(words.size()))
      return std::numeric_limits<int>::max();

    int nbErrors = 0;
    for (unsigned int i = 0; i < words.size(); i++)
      if (!config.has("FORM", config.getWordIndex()+i, 0) or config.getConst("FORM", config.getWordIndex()+i, 0) != words[i])
        nbErrors++;

    return nbErrors;
  };
}

/// @brief Set up "SPLIT": apply the index-th appliable split transition.
/// @param index Position in config.getAppliableSplitTransitions().
///
/// Cost is the maximal int when `index` is out of range, otherwise the cost of
/// the designated split transition.
void Transition::initSplit(int index)
{
  sequence.emplace_back(Action::split(index));

  cost = [index](const Config & config)
  {
    auto & transitions = config.getAppliableSplitTransitions();

    if (index < 0 or index >= static_cast<int>(transitions.size()))
      return std::numeric_limits<int>::max();

    return transitions[index]->getCost(config);
  };
}

/// @brief Set up "eager_SHIFT": push the current word index on the stack.
///
/// Cost is 0 when the current word is not a token, otherwise the number of
/// pending links between the current word and the stack elements
/// (see getNbLinkedWith).
void Transition::initEagerShift()
{
  sequence.emplace_back(Action::pushWordIndexOnStack());
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

  cost = [](const Config & config)
  {
    if (!config.isToken(config.getWordIndex()))
      return 0;

    return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);
  };
}

/// @brief Set up "standard_SHIFT": push the current word index on the stack.
///        Cost is always 0.
void Transition::initStandardShift()
{
  sequence.emplace_back(Action::pushWordIndexOnStack());
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

  cost = [](const Config &)
  {
    return 0;
  };
}

/// @brief Set up "eager_LEFT_rel": attach stack[0] to buffer[0] with `label`,
///        then pop stack[0].
/// @param label Dependency label written in the deprel column of stack[0].
///
/// Cost is 0 when both the head and the label agree with getConst. Otherwise it
/// is the number of pending links between stack[0] and the rest of the sentence
/// after the current word, plus 1 when the label differs.
void Transition::initEagerLeft_rel(std::string label)
{
  sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
  sequence.emplace_back(Action::popStack(0));

  cost = [label](const Config & config)
  {
    auto depIndex = config.getStack(0);
    auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
    auto govIndex = config.getWordIndex();

    if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0))
      return 0;

    int total = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config);

    if (label != config.getConst(Config::deprelColName, depIndex, 0))
      ++total;

    return total;
  };
}

/// @brief Set up "standard_LEFT_rel": attach stack[1] to stack[0] with `label`,
///        then pop stack[1].
/// @param label Dependency label written in the deprel column of stack[1].
///
/// Cost is 0 when both the head and the label agree with getConst. Otherwise it
/// is the number of pending links between stack[1] and (the rest of the sentence
/// from the current word on, plus the stack from position 1), plus 1 when the
/// label differs.
void Transition::initStandardLeft_rel(std::string label)
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Stack, 1));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 1, label));
  sequence.emplace_back(Action::popStack(1));

  cost = [label](const Config & config)
  {
    auto depIndex = config.getStack(1);
    auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
    auto govIndex = config.getStack(0);

    if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0))
      return 0;

    int total = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, depIndex, config);
    total += getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config);

    if (label != config.getConst(Config::deprelColName, depIndex, 0))
      ++total;

    return total;
  };
}

/// @brief Set up "eager_LEFT": unlabeled variant of eager_LEFT_rel.
///
/// Cost is 0 when the head of stack[0] returned by getConst is the current word,
/// otherwise the number of pending links between stack[0] and the rest of the
/// sentence after the current word.
void Transition::initEagerLeft()
{
  sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
  sequence.emplace_back(Action::popStack(0));

  cost = [](const Config & config)
  {
    auto depIndex = config.getStack(0);
    auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
    auto govIndex = config.getWordIndex();

    if (depGovIndex == std::to_string(govIndex))
      return 0;

    int total = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config);

    return total;
  };
}

/// @brief Set up "eager_RIGHT_rel": attach buffer[0] to stack[0] with `label`,
///        then push the current word index on the stack.
/// @param label Dependency label written in the deprel column of buffer[0].
///
/// Cost is 0 when both the head and the label agree with getConst. Otherwise it
/// is the number of pending links between the current word and the stack from
/// position 1, plus the pending heads of the current word later in the sentence,
/// plus 1 when the label differs.
void Transition::initEagerRight_rel(std::string label)
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
  sequence.emplace_back(Action::pushWordIndexOnStack());
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

  cost = [label](const Config & config)
  {
    auto govIndex = config.getStack(0);
    auto depIndex = config.getWordIndex();
    auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);

    if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0))
      return 0;

    int total = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config);
    total += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config);

    if (label != config.getConst(Config::deprelColName, depIndex, 0))
      ++total;

    return total;
  };
}

/// @brief Set up "standard_RIGHT_rel": attach stack[0] to stack[1] with `label`,
///        then pop stack[0].
/// @param label Dependency label written in the deprel column of stack[0].
///
/// Cost is 0 when both the head and the label agree with getConst. Otherwise it
/// is the number of pending links between stack[0] and (the rest of the sentence
/// from the current word on, plus the stack from position 2), plus 1 when the
/// label differs.
void Transition::initStandardRight_rel(std::string label)
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 1, Config::Object::Stack, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
  sequence.emplace_back(Action::popStack(0));

  cost = [label](const Config & config)
  {
    auto govIndex = config.getStack(1);
    auto depIndex = config.getStack(0);
    auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);

    if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0))
      return 0;

    int total = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, depIndex, config);
    total += getNbLinkedWith(2, config.getStackSize()-1, Config::Object::Stack, depIndex, config);

    if (label != config.getConst(Config::deprelColName, depIndex, 0))
      ++total;

    return total;
  };
}

/// @brief Set up "eager_RIGHT": unlabeled variant of eager_RIGHT_rel.
///
/// Cost is 0 when the head of the current word returned by getConst is stack[0].
/// Otherwise it is the number of pending links between the current word and the
/// stack from position 1, plus the pending heads of the current word later in
/// the sentence.
void Transition::initEagerRight()
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::pushWordIndexOnStack());
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

  cost = [](const Config & config)
  {
    auto govIndex = config.getStack(0);
    auto depIndex = config.getWordIndex();
    auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);

    if (depGovIndex == std::to_string(govIndex))
      return 0;

    int total = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config);
    total += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config);

    return total;
  };
}

/// @brief Set up "REDUCE_strict": pop stack[0], which must already have a head.
///
/// Cost is 0 when stack[0] is not a token, otherwise the number of pending
/// dependents of stack[0] in the rest of the sentence from the current word on
/// (see getNbLinkedWithDeps).
void Transition::initReduce_strict()
{
  sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
  sequence.emplace_back(Action::popStack(0));

  cost = [](const Config & config)
  {
    auto stackIndex = config.getStack(0);
    auto wordIndex = config.getWordIndex();

    if (!config.isToken(stackIndex))
      return 0;

    int total = getNbLinkedWithDeps(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);

    return total;
  };
}

/// @brief Set up "REDUCE_relaxed": pop stack[0] without requiring it to have a head.
///
/// Cost is 0 when stack[0] is not a token, otherwise the number of pending links
/// between stack[0] and the rest of the sentence from the current word on
/// (see getNbLinkedWith).
void Transition::initReduce_relaxed()
{
  sequence.emplace_back(Action::popStack(0));

  cost = [](const Config & config)
  {
    auto stackIndex = config.getStack(0);
    auto wordIndex = config.getWordIndex();

    if (!config.isToken(stackIndex))
      return 0;

    int total = getNbLinkedWith(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);

    return total;
  };
}

/// @brief Set up "EOS": mark the buffer word at `bufferIndex` as end of sentence.
/// @param bufferIndex Buffer-relative index of the word to mark.
///
/// Cost is 0 when the EOS column returned by getConst for that word equals
/// Config::EOSSymbol1, the maximal int otherwise.
void Transition::initEOS(int bufferIndex)
{
  sequence.emplace_back(Action::setRoot(bufferIndex));
  sequence.emplace_back(Action::updateIds(bufferIndex));
  sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1));
  sequence.emplace_back(Action::emptyStack());

  cost = [bufferIndex](const Config & config)
  {
    int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex);

    if (config.getConst(Config::EOSColName, lineIndex, 0) != Config::EOSSymbol1)
      return std::numeric_limits<int>::max();

    return 0;
  };
}

/// @brief Set up "deprel": label the last attached word with `label`.
/// @param label Dependency label.
///
/// Cost is 0 when the deprel returned by getConst for config.getLastAttached()
/// equals `label`, 1 otherwise.
void Transition::initDeprel(std::string label)
{
  sequence.emplace_back(Action::deprel(label));

  cost = [label](const Config & config)
  {
    return config.getConst(Config::deprelColName, config.getLastAttached(), 0) == label ? 0 : 1;
  };
}

/// @brief Count the pending links, in either direction, between `withIndex` and
///        the tokens of a range.
/// @param firstIndex First position of the range (inclusive).
/// @param lastIndex  Last position of the range (inclusive).
/// @param object     When Stack, positions are stack positions resolved with
///                   config.getStack; otherwise they are used as word indexes.
/// @param withIndex  Word index the links are counted against.
/// @param config     Configuration to inspect.
/// @return Number of tokens in the range that are the getConst head of
///         `withIndex` while `withIndex` has no predicted head, plus the number
///         of tokens whose getConst head is `withIndex` while they have no
///         predicted head.
int Transition::getNbLinkedWith(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
{
  auto govIndex = config.getConst(Config::headColName, withIndex, 0);
  auto govIndexPredicted = config.getAsFeature(Config::headColName, withIndex);

  int nbLinkedWith = 0;

  for (int i = firstIndex; i <= lastIndex; ++i)
  {
    int index = i;
    if (object == Config::Object::Stack)
      index = config.getStack(i);

    if (!config.isToken(index))
      continue;

    auto otherGovIndex = config.getConst(Config::headColName, index, 0);
    auto otherGovIndexPredicted = config.getAsFeature(Config::headColName, index);

    if (govIndex == std::to_string(index) and util::isEmpty(govIndexPredicted))
      ++nbLinkedWith;
    if (otherGovIndex == std::to_string(withIndex) and util::isEmpty(otherGovIndexPredicted))
      ++nbLinkedWith;
  }

  return nbLinkedWith;
}

/// @brief Count the tokens of a range that are the pending head of `withIndex`.
///
/// Same parameters as getNbLinkedWith, but only the "range token is the getConst
/// head of `withIndex` and `withIndex` has no predicted head" direction is counted.
int Transition::getNbLinkedWithHead(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
{
  auto govIndex = config.getConst(Config::headColName, withIndex, 0);
  auto govIndexPredicted = config.getAsFeature(Config::headColName, withIndex);

  int nbLinkedWith = 0;

  for (int i = firstIndex; i <= lastIndex; ++i)
  {
    int index = i;
    if (object == Config::Object::Stack)
      index = config.getStack(i);

    if (!config.isToken(index))
      continue;

    if (govIndex == std::to_string(index) and util::isEmpty(govIndexPredicted))
      ++nbLinkedWith;
  }

  return nbLinkedWith;
}

/// @brief Count the tokens of a range that are pending dependents of `withIndex`.
///
/// Same parameters as getNbLinkedWith, but only the "range token has `withIndex`
/// as getConst head and has no predicted head" direction is counted.
int Transition::getNbLinkedWithDeps(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
{
  int nbLinkedWith = 0;

  for (int i = firstIndex; i <= lastIndex; ++i)
  {
    int index = i;
    if (object == Config::Object::Stack)
      index = config.getStack(i);

    if (!config.isToken(index))
      continue;

    auto otherGovIndex = config.getConst(Config::headColName, index, 0);
    auto otherGovIndexPredicted = config.getAsFeature(Config::headColName, index);

    if (otherGovIndex == std::to_string(withIndex) and util::isEmpty(otherGovIndexPredicted))
      ++nbLinkedWith;
  }

  return nbLinkedWith;
}

/// @brief Find the first token of the sentence containing `baseIndex`.
/// @param baseIndex Index the backward scan starts from.
/// @param config    Configuration to inspect.
/// @return The smallest token index reached while walking backwards before a
///         token whose EOS column equals Config::EOSSymbol1 is met or
///         config.has(0, i, 0) becomes false; `baseIndex` if no token qualifies.
int Transition::getFirstIndexOfSentence(int baseIndex, const Config & config)
{
  int firstIndex = baseIndex;

  for (int i = baseIndex; config.has(0, i, 0); --i)
  {
    if (!config.isToken(i))
      continue;

    // An end-of-sentence token belongs to the previous sentence: stop before it.
    if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
      break;

    firstIndex = i;
  }

  return firstIndex;
}

/// @brief Find the last token of the sentence containing `baseIndex`.
/// @param baseIndex Index the forward scan starts from.
/// @param config    Configuration to inspect.
/// @return The index of the first token, walking forwards, whose EOS column
///         equals Config::EOSSymbol1; otherwise the last token seen before
///         config.has(0, i, 0) becomes false; `baseIndex` if no token was seen.
int Transition::getLastIndexOfSentence(int baseIndex, const Config & config)
{
  int lastIndex = baseIndex;

  for (int i = baseIndex; config.has(0, i, 0); ++i)
  {
    if (!config.isToken(i))
      continue;

    // The end-of-sentence token is part of the sentence: record it, then stop.
    lastIndex = i;

    if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
      break;
  }

  return lastIndex;
}