// Transition.cpp (32.35 KiB) - authored by Franck Dary
#include "Transition.hpp"
#include <regex>
// Builds a transition from its textual description.
//
// `name` may start with an optional "<state> " prefix: the text between the
// angle brackets is stored in `state`, the remainder in `this->name`.
// The remainder is then tried, in table order, against every pattern below;
// the handler of the first pattern for which util::doIfNameMatch returns true
// initialises the transition and construction ends there.
//
// Every std::exception raised on the way (including the "no match" one thrown
// when no pattern applies) is re-reported through util::myThrow together with
// the transition name.
Transition::Transition(const std::string & name)
{
  // Handlers receive the match by const reference, so the std::smatch is not
  // copied each time a handler is invoked.
  std::vector<std::pair<std::regex, std::function<void(const std::smatch &)>>> inits
  {
    {std::regex("WRITE ([bs])\\.(.+) (.+) (.+)"),
      [this](const auto & sm){initWrite(sm[3], sm[1], sm[2], sm[4]);}},
    {std::regex("WRITESCORE ([bs])\\.(.+) (.+)"),
      [this](const auto & sm){initWriteScore(sm[3], sm[1], sm[2]);}},
    {std::regex("ADD ([bs])\\.(.+) (.+) (.+)"),
      [this](const auto & sm){initAdd(sm[3], sm[1], sm[2], sm[4]);}},
    {std::regex("eager_SHIFT"),
      [this](const auto &){initEagerShift();}},
    {std::regex("gold_eager_SHIFT"),
      [this](const auto &){initGoldEagerShift();}},
    {std::regex("standard_SHIFT"),
      [this](const auto &){initStandardShift();}},
    {std::regex("REDUCE_strict"),
      [this](const auto &){initReduce_strict();}},
    {std::regex("gold_REDUCE_strict"),
      [this](const auto &){initGoldReduce_strict();}},
    {std::regex("REDUCE_relaxed"),
      [this](const auto &){initReduce_relaxed();}},
    {std::regex("eager_LEFT_rel (.+)"),
      [this](const auto & sm){initEagerLeft_rel(sm[1]);}},
    {std::regex("gold_eager_LEFT_rel (.+)"),
      [this](const auto & sm){initGoldEagerLeft_rel(sm[1]);}},
    {std::regex("eager_RIGHT_rel (.+)"),
      [this](const auto & sm){initEagerRight_rel(sm[1]);}},
    {std::regex("gold_eager_RIGHT_rel (.+)"),
      [this](const auto & sm){initGoldEagerRight_rel(sm[1]);}},
    {std::regex("standard_LEFT_rel (.+)"),
      [this](const auto & sm){initStandardLeft_rel(sm[1]);}},
    {std::regex("standard_RIGHT_rel (.+)"),
      [this](const auto & sm){initStandardRight_rel(sm[1]);}},
    {std::regex("eager_LEFT"),
      [this](const auto &){initEagerLeft();}},
    {std::regex("eager_RIGHT"),
      [this](const auto &){initEagerRight();}},
    {std::regex("deprel (.+)"),
      [this](const auto & sm){initDeprel(sm[1]);}},
    {std::regex("EOS b\\.(.+)"),
      [this](const auto & sm){initEOS(std::stoi(sm[1]));}},
    {std::regex("NOTHING"),
      [this](const auto &){initNothing();}},
    {std::regex("NOTEOS b\\.(.+)"),
      [this](const auto & sm){initNotEOS(std::stoi(sm[1]));}},
    {std::regex("IGNORECHAR"),
      [this](const auto &){initIgnoreChar();}},
    {std::regex("ENDWORD"),
      [this](const auto &){initEndWord();}},
    {std::regex("ADDCHARTOWORD (.+)"),
      [this](const auto & sm){initAddCharToWord(std::stoi(sm.str(1)));}},
    {std::regex("SPLIT (.+)"),
      [this](const auto & sm){initSplit(std::stoi(sm.str(1)));}},
    {std::regex("TRANSFORMSUFFIX (.+) ([bs])\\.(.+) (.+) ([bs])\\.(.+) (.+)"),
      [this](const auto & sm){initTransformSuffix(sm[1], sm[2], sm[3], sm[4], sm[5], sm[6], sm[7]);}},
    {std::regex("UPPERCASE (.+) ([bs])\\.(.+)"),
      [this](const auto & sm){initUppercase(sm[1], sm[2], sm[3]);}},
    {std::regex("UPPERCASEINDEX (.+) ([bs])\\.(.+) (.+)"),
      [this](const auto & sm){initUppercaseIndex(sm[1], sm[2], sm[3], sm[4]);}},
    {std::regex("NOTHING (.+) ([bs])\\.(.+)"),
      [this](const auto & sm){initNothing(sm[1], sm[2], sm[3]);}},
    {std::regex("LOWERCASE (.+) ([bs])\\.(.+)"),
      [this](const auto & sm){initLowercase(sm[1], sm[2], sm[3]);}},
    {std::regex("LOWERCASEINDEX (.+) ([bs])\\.(.+) (.+)"),
      [this](const auto & sm){initLowercaseIndex(sm[1], sm[2], sm[3], sm[4]);}},
    // The inner group is a non-capturing "(?:...)" group (it used to be
    // spelled "(:?", i.e. an optional ':' inside a capturing group).
    // Group 1 is still the consumed word and group 2 the '@'-prefixed parts.
    {std::regex("SPLITWORD ([^@]+)((?:@[^@]+)+)"),
      [this](const auto & sm)
      {
        std::vector<std::string> splitRes{sm[1]};
        auto splited = util::split(std::string(sm[2]), '@');
        for (auto & s : splited)
          splitRes.emplace_back(s);
        initSplitWord(splitRes);
      }},
  };

  try
  {
    // Separate the optional "<state> " prefix from the transition name proper.
    if (!util::doIfNameMatch(std::regex("(<(.+)> )?(.+)"), name, [this](const auto & sm)
          {
            this->state = sm[2];
            this->name = sm[3];
          }))
      util::myThrow("doesn't match nameRegex");

    // First matching pattern wins.
    for (auto & it : inits)
      if (util::doIfNameMatch(it.first, this->name, it.second))
        return;

    throw std::invalid_argument("no match");
  } catch (const std::exception & e) {util::myThrow(fmt::format("Invalid name '{}' ({})", this->name, e.what()));}
}
// Applies the transition to `config` after updating the optional "ENTROPY"
// and "SURPRISAL" columns, when the configuration has them.
//
// ENTROPY: skipped for transitions whose name contains "SHIFT" or "REDUCE".
// Otherwise `entropy` is summed (with the `mean` flag set) on the top of the
// stack when the name contains "LEFT", and on the current word index when it
// does not.
// SURPRISAL: minus the log of the chosen action score is summed on the
// current word index.
void Transition::apply(Config & config, float entropy)
{
  if (config.hasColIndex("ENTROPY"))
  {
    bool mean = true;
    const bool isShift = name.find("SHIFT") != std::string::npos;
    const bool isReduce = name.find("REDUCE") != std::string::npos;

    if (!isShift and !isReduce)
    {
      if (name.find("LEFT") == std::string::npos)
      {
        auto onBuffer = Action::sumToHypothesis("ENTROPY", config.getWordIndex(), entropy, mean);
        onBuffer.apply(config, onBuffer);
      }
      else
      {
        auto onStack = Action::sumToHypothesis("ENTROPY", config.getStack(0), entropy, mean);
        onStack.apply(config, onStack);
      }
    }
  }

  if (config.hasColIndex("SURPRISAL"))
  {
    float surprisal = -log(config.getChosenActionScore());
    auto surprisalAction = Action::sumToHypothesis("SURPRISAL", config.getWordIndex(), surprisal, false);
    surprisalAction.apply(config, surprisalAction);
  }

  // Finally run the transition's own action sequence.
  apply(config);
}
// Runs every action of `sequence` on `config`, in insertion order.
void Transition::apply(Config & config)
{
  for (auto & step : sequence)
  {
    step.apply(config, step);
  }
}
// Tells whether the transition can be applied to `config`.
//
// Returns false when the transition is bound to a state different from the
// configuration's one, when one of its actions is not appliable, or when its
// precondition rejects the configuration. Exceptions raised by those checks
// are re-reported through util::myThrow, prefixed with the transition name.
bool Transition::appliable(const Config & config) const
{
  try
  {
    const bool wrongState = !state.empty() && state != config.getState();
    if (wrongState)
      return false;

    for (const Action & step : sequence)
    {
      if (!step.appliable(config, step))
        return false;
    }

    if (!precondition(config))
      return false;
  }
  catch (std::exception & e)
  {
    util::myThrow(fmt::format("transition '{}' {}", name, e.what()));
  }

  return true;
}
// Evaluates `costDynamic` on `config`.
// An exception raised by the cost function is re-reported through
// util::myThrow with the transition name; 0 is the fallback value.
int Transition::getCostDynamic(const Config & config) const
{
  int cost = 0;

  try
  {
    cost = costDynamic(config);
  }
  catch (std::exception & e)
  {
    util::myThrow(fmt::format("transition '{}' {}", name, e.what()));
  }

  return cost;
}
// Evaluates `costStatic` on `config`.
// An exception raised by the cost function is re-reported through
// util::myThrow with the transition name; 0 is the fallback value.
int Transition::getCostStatic(const Config & config) const
{
  int cost = 0;

  try
  {
    cost = costStatic(config);
  }
  catch (std::exception & e)
  {
    util::myThrow(fmt::format("transition '{}' {}", name, e.what()));
  }

  return cost;
}
const std::string & Transition::getName() const
{
return name;
}
// WRITE: appends an action writing `value` as hypothesis of column `colName`
// on the line designated by (`object`, `index`).
// Both costs are 0 when the value read with config.getConst for that line
// equals `value`, 1 otherwise.
void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
{
  auto target = Config::str2object(object);
  int offset = std::stoi(index);

  sequence.emplace_back(Action::addHypothesisRelative(colName, target, offset, value));

  costDynamic = [colName, target, offset, value](const Config & config)
  {
    int line = config.getRelativeWordIndex(target, offset);
    return config.getConst(colName, line, 0) == value ? 0 : 1;
  };

  costStatic = costDynamic;
}
// WRITESCORE: appends an Action::writeScore for column `colName` on the line
// designated by (`object`, `index`). Both costs are always 0.
void Transition::initWriteScore(std::string colName, std::string object, std::string index)
{
  auto target = Config::str2object(object);
  int offset = std::stoi(index);

  sequence.emplace_back(Action::writeScore(colName, target, offset));

  costDynamic = [](const Config &) { return 0; };
  costStatic = costDynamic;
}
// ADD: appends an action adding `value` to the hypothesis of column `colName`
// on the line designated by (`object`, `index`).
// Both costs are 0 when `value` is one of the '|'-separated parts of the value
// read with config.getConst for that line, 1 otherwise.
void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
{
  auto target = Config::str2object(object);
  int offset = std::stoi(index);

  sequence.emplace_back(Action::addToHypothesisRelative(colName, target, offset, value));

  costDynamic = [colName, target, offset, value](const Config & config)
  {
    int line = config.getRelativeWordIndex(target, offset);
    auto parts = util::split(config.getConst(colName, line, 0).get(), '|');

    int cost = 1;
    for (auto & candidate : parts)
    {
      if (candidate == value)
      {
        cost = 0;
        break;
      }
    }

    return cost;
  };

  costStatic = costDynamic;
}
// NOTHING: no action is appended; both costs are always 0.
void Transition::initNothing()
{
  costDynamic = [](const Config &) { return 0; };
  costStatic = costDynamic;
}
// IGNORECHAR: appends an action ignoring the current character.
// Both costs are the maximal int when the current letter is exactly the letter
// that comes next in the FORM read with config.getConst (compared with the FORM
// read with config.getAsFeature), and 0 otherwise.
void Transition::initIgnoreChar()
{
  sequence.emplace_back(Action::ignoreCurrentCharacter());

  costDynamic = [](const Config & config)
  {
    auto currentLetter = fmt::format("{}", config.getLetter(config.getCharacterIndex()));
    auto referenceLetters = util::splitAsUtf8(config.getConst("FORM", config.getWordIndex(), 0).get());
    auto builtLetters = util::splitAsUtf8(config.getAsFeature("FORM", config.getWordIndex()).get());

    // Nothing left to read from the reference word: ignoring is free.
    if (builtLetters.size() >= referenceLetters.size())
      return 0;

    if (referenceLetters[builtLetters.size()] == currentLetter)
      return std::numeric_limits<int>::max();

    return 0;
  };

  costStatic = costDynamic;
}
// ENDWORD: appends Action::endWord().
// Both costs are 0 when the FORM read with config.getConst equals the FORM read
// with config.getAsFeature for the current word, the maximal int otherwise.
void Transition::initEndWord()
{
  sequence.emplace_back(Action::endWord());

  costDynamic = [](const Config & config)
  {
    if (config.getConst("FORM", config.getWordIndex(), 0) == config.getAsFeature("FORM", config.getWordIndex()))
    {
      return 0;
    }

    return std::numeric_limits<int>::max();
  };

  costStatic = costDynamic;
}
// ADDCHARTOWORD n: appends the actions that add the next `n` characters to
// the FORM of the current buffer word and move the character index by `n`.
//
// Both costs are the maximal int when the n-th next character does not exist,
// when the current word is not a token, or when the extended form is not a
// prefix of the FORM read with config.getConst; otherwise they are the number
// of bytes still missing to reach that FORM.
void Transition::initAddCharToWord(int n)
{
  sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::addLinesIfNeeded(0));
  sequence.emplace_back(Action::addCharsToCol("FORM", n, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::moveCharacterIndex(n));

  costDynamic = [n](const Config & config)
  {
    const int forbidden = std::numeric_limits<int>::max();

    if (!config.hasCharacter(config.getCharacterIndex()+n-1))
      return forbidden;
    if (!config.isToken(config.getWordIndex()))
      return forbidden;

    std::string extended = config.getAsFeature("FORM", config.getWordIndex());
    std::string reference = config.getConst("FORM", config.getWordIndex(), 0);

    // Simulate the addition of the next n letters.
    for (int k = 0; k < n; k++)
      extended += fmt::format("{}", config.getLetter(config.getCharacterIndex()+k));

    if (extended.size() > reference.size())
      return forbidden;

    // The extended form has to be a (byte-wise) prefix of the reference form.
    if (reference.compare(0, extended.size(), extended) != 0)
      return forbidden;

    return std::abs(static_cast<int>(reference.size()) - static_cast<int>(extended.size()));
  };

  costStatic = costDynamic;
}
// SPLITWORD: `words[0]` is the word consumed from the raw input, the following
// elements are the words it is split into.
//
// Appended actions: emptiness assertions on the current buffer line, creation
// of the needed lines, filling of the FORM of the current line with the
// consumed characters, then for every produced word the writing of its FORM
// and the copy of the raw range columns from the current line, and finally the
// setting of the multiword ids.
//
// Both costs are 0 only when the current word is a multiword whose size plus 2
// equals words.size() and every line starting at the current word has a FORM
// equal (case-insensitively) to the matching element of `words`; otherwise
// they are the maximal int.
void Transition::initSplitWord(std::vector<std::string> words)
{
  auto consumed = util::splitAsUtf8(words[0]);

  sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::assertIsEmpty("FORM", Config::Object::Buffer, 0));
  sequence.emplace_back(Action::addLinesIfNeeded(words.size()-1));
  sequence.emplace_back(Action::addCharsToCol("FORM", consumed.size(), Config::Object::Buffer, 0));
  sequence.emplace_back(Action::consumeCharacterIndex(consumed));

  for (unsigned int offset = 1; offset < words.size(); offset++)
  {
    sequence.emplace_back(Action::addHypothesisRelativeRelaxed("FORM", Config::Object::Buffer, offset, words[offset]));
    sequence.emplace_back(Action::copyContent(Config::rawRangeStartColName, Config::Object::Buffer, 0, Config::rawRangeStartColName, Config::Object::Buffer, offset));
    sequence.emplace_back(Action::copyContent(Config::rawRangeEndColName, Config::Object::Buffer, 0, Config::rawRangeEndColName, Config::Object::Buffer, offset));
  }

  sequence.emplace_back(Action::setMultiwordIds(words.size()-1));

  costDynamic = [words](const Config & config)
  {
    const int forbidden = std::numeric_limits<int>::max();

    if (!config.isMultiword(config.getWordIndex()))
      return forbidden;
    if (config.getMultiwordSize(config.getWordIndex())+2 != static_cast<int>(words.size()))
      return forbidden;

    for (unsigned int offset = 0; offset < words.size(); offset++)
    {
      if (!config.has("FORM", config.getWordIndex()+offset, 0))
        return forbidden;
      if (util::lower(config.getConst("FORM", config.getWordIndex()+offset, 0)) != util::lower(words[offset]))
        return forbidden;
    }

    return 0;
  };

  costStatic = costDynamic;
}
// SPLIT index: appends Action::split(index).
// Both costs delegate to the dynamic cost of the `index`-th element of
// config.getAppliableSplitTransitions(), or are the maximal int when `index`
// is outside of that list.
void Transition::initSplit(int index)
{
  sequence.emplace_back(Action::split(index));

  costDynamic = [index](const Config & config)
  {
    auto & candidates = config.getAppliableSplitTransitions();

    const bool inRange = index >= 0 and index < static_cast<int>(candidates.size());
    if (!inRange)
      return std::numeric_limits<int>::max();

    return candidates[index]->getCostDynamic(config);
  };

  costStatic = costDynamic;
}
// eager_SHIFT: pushes the current word index on the stack.
// Dynamic cost: 0 when the current word is not a token, otherwise the value
// getNbLinkedWith computes between the current word and the whole stack.
// Static cost: always 0.
void Transition::initEagerShift()
{
  sequence.emplace_back(Action::pushWordIndexOnStack());
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

  costDynamic = [](const Config & config)
  {
    if (config.isToken(config.getWordIndex()))
      return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);

    return 0;
  };

  costStatic = [](const Config &) { return 0; };
}
// gold_eager_SHIFT: same actions and costs as eager_SHIFT, plus a precondition
// that only accepts configurations for which the dynamic cost would be 0.
void Transition::initGoldEagerShift()
{
  sequence.emplace_back(Action::pushWordIndexOnStack());
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

  costDynamic = [](const Config & config)
  {
    if (config.isToken(config.getWordIndex()))
      return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);

    return 0;
  };

  costStatic = [](const Config &) { return 0; };

  precondition = [](const Config & config)
  {
    if (config.isToken(config.getWordIndex()))
      return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config) == 0;

    return true;
  };
}
// standard_SHIFT: pushes the current word index on the stack.
// Both costs are always 0.
void Transition::initStandardShift()
{
  sequence.emplace_back(Action::pushWordIndexOnStack());
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

  costDynamic = [](const Config &) { return 0; };
  costStatic = costDynamic;
}
// eager_LEFT_rel label: attaches the top of the stack to the current buffer
// word, writes `label` as its dependency relation and pops the stack.
//
// Dynamic cost: value of getNbLinkedWith between the top of the stack and the
// words following the current one up to the end of the sentence, plus one when
// `label` differs from the relation read with config.getConst.
// Static cost: 0 when both the head and the relation read with config.getConst
// agree with this attachment, 1 otherwise.
void Transition::initEagerLeft_rel(std::string label)
{
  sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
  sequence.emplace_back(Action::popStack(0));

  costDynamic = [label](const Config & config)
  {
    auto dependent = config.getStack(0);
    auto governor = config.getWordIndex();

    int cost = getNbLinkedWith(governor+1, getLastIndexOfSentence(governor, config), Config::Object::Buffer, dependent, config);
    if (label != config.getConst(Config::deprelColName, dependent, 0))
      cost += 1;

    return cost;
  };

  costStatic = [label](const Config & config)
  {
    auto dependent = config.getStack(0);
    auto expectedHead = config.getConst(Config::headColName, dependent, 0);
    auto governor = config.getWordIndex();

    const bool correct = expectedHead == std::to_string(governor) and label == config.getConst(Config::deprelColName, dependent, 0);
    return correct ? 0 : 1;
  };
}
// gold_eager_LEFT_rel label: same actions and costs as eager_LEFT_rel, plus a
// precondition that only accepts configurations whose dynamic cost would be 0.
void Transition::initGoldEagerLeft_rel(std::string label)
{
  sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
  sequence.emplace_back(Action::popStack(0));

  costDynamic = [label](const Config & config)
  {
    auto dependent = config.getStack(0);
    auto governor = config.getWordIndex();

    int cost = getNbLinkedWith(governor+1, getLastIndexOfSentence(governor, config), Config::Object::Buffer, dependent, config);
    if (label != config.getConst(Config::deprelColName, dependent, 0))
      cost += 1;

    return cost;
  };

  costStatic = [label](const Config & config)
  {
    auto dependent = config.getStack(0);
    auto expectedHead = config.getConst(Config::headColName, dependent, 0);
    auto governor = config.getWordIndex();

    const bool correct = expectedHead == std::to_string(governor) and label == config.getConst(Config::deprelColName, dependent, 0);
    return correct ? 0 : 1;
  };

  // Same computation as the dynamic cost, turned into a yes/no answer.
  precondition = [label](const Config & config)
  {
    auto dependent = config.getStack(0);
    auto governor = config.getWordIndex();

    int cost = getNbLinkedWith(governor+1, getLastIndexOfSentence(governor, config), Config::Object::Buffer, dependent, config);
    if (label != config.getConst(Config::deprelColName, dependent, 0))
      cost += 1;

    return cost == 0;
  };
}
// standard_LEFT_rel label: attaches the second stack element to the top of the
// stack, writes `label` as its dependency relation and removes it from the
// stack.
//
// Both costs are 0 when the head and relation read with config.getConst agree
// with this attachment. Otherwise they add up getNbLinkedWith between the
// dependent and the rest of the sentence in the buffer, getNbLinkedWith between
// the dependent and the stack from position 1, and one more when `label`
// differs from the relation read with config.getConst.
void Transition::initStandardLeft_rel(std::string label)
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Stack, 1));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 1, label));
  sequence.emplace_back(Action::popStack(1));

  costDynamic = [label](const Config & config)
  {
    auto dependent = config.getStack(1);
    auto expectedHead = config.getConst(Config::headColName, dependent, 0);
    auto governor = config.getStack(0);

    const bool correct = expectedHead == std::to_string(governor) and label == config.getConst(Config::deprelColName, dependent, 0);
    if (correct)
      return 0;

    int cost = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, dependent, config);
    cost += getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, dependent, config);
    if (label != config.getConst(Config::deprelColName, dependent, 0))
      cost += 1;

    return cost;
  };

  costStatic = costDynamic;
}
// eager_LEFT: attaches the top of the stack to the current buffer word and pops
// the stack, without writing any relation.
// Both costs are 0 when the head read with config.getConst for the dependent is
// the current word; otherwise they are getNbLinkedWith between the dependent
// and the words following the current one up to the end of the sentence.
void Transition::initEagerLeft()
{
  sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
  sequence.emplace_back(Action::popStack(0));

  costDynamic = [](const Config & config)
  {
    auto dependent = config.getStack(0);
    auto expectedHead = config.getConst(Config::headColName, dependent, 0);
    auto governor = config.getWordIndex();

    if (expectedHead == std::to_string(governor))
      return 0;

    return getNbLinkedWith(governor+1, getLastIndexOfSentence(governor, config), Config::Object::Buffer, dependent, config);
  };

  costStatic = costDynamic;
}
// eager_RIGHT_rel label: attaches the current buffer word to the top of the
// stack, writes `label` as its dependency relation and pushes it on the stack.
//
// Dynamic cost: getNbLinkedWith between the dependent and the stack from
// position 1, plus getNbLinkedWithHead between the dependent and the following
// words of the sentence, plus one when `label` differs from the relation read
// with config.getConst.
// Static cost: 0 when both the head and the relation read with config.getConst
// agree with this attachment, 1 otherwise.
void Transition::initEagerRight_rel(std::string label)
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
  sequence.emplace_back(Action::pushWordIndexOnStack());
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

  costDynamic = [label](const Config & config)
  {
    auto dependent = config.getWordIndex();

    int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, dependent, config);
    cost += getNbLinkedWithHead(dependent+1, getLastIndexOfSentence(dependent, config), Config::Object::Buffer, dependent, config);
    if (label != config.getConst(Config::deprelColName, dependent, 0))
      cost += 1;

    return cost;
  };

  costStatic = [label](const Config & config)
  {
    auto governor = config.getStack(0);
    auto dependent = config.getWordIndex();
    auto expectedHead = config.getConst(Config::headColName, dependent, 0);

    const bool correct = expectedHead == std::to_string(governor) and label == config.getConst(Config::deprelColName, dependent, 0);
    return correct ? 0 : 1;
  };
}
// gold_eager_RIGHT_rel label: same actions and costs as eager_RIGHT_rel, plus a
// precondition that only accepts configurations whose dynamic cost would be 0.
void Transition::initGoldEagerRight_rel(std::string label)
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
  sequence.emplace_back(Action::pushWordIndexOnStack());
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

  costDynamic = [label](const Config & config)
  {
    auto dependent = config.getWordIndex();

    int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, dependent, config);
    cost += getNbLinkedWithHead(dependent+1, getLastIndexOfSentence(dependent, config), Config::Object::Buffer, dependent, config);
    if (label != config.getConst(Config::deprelColName, dependent, 0))
      cost += 1;

    return cost;
  };

  costStatic = [label](const Config & config)
  {
    auto governor = config.getStack(0);
    auto dependent = config.getWordIndex();
    auto expectedHead = config.getConst(Config::headColName, dependent, 0);

    const bool correct = expectedHead == std::to_string(governor) and label == config.getConst(Config::deprelColName, dependent, 0);
    return correct ? 0 : 1;
  };

  // Same computation as the dynamic cost, turned into a yes/no answer.
  precondition = [label](const Config & config)
  {
    auto dependent = config.getWordIndex();

    int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, dependent, config);
    cost += getNbLinkedWithHead(dependent+1, getLastIndexOfSentence(dependent, config), Config::Object::Buffer, dependent, config);
    if (label != config.getConst(Config::deprelColName, dependent, 0))
      cost += 1;

    return cost == 0;
  };
}
// standard_RIGHT_rel label: attaches the top of the stack to the second stack
// element, writes `label` as its dependency relation and pops the stack.
//
// Both costs are 0 when the head and relation read with config.getConst agree
// with this attachment. Otherwise they add up getNbLinkedWith between the
// dependent and the rest of the sentence in the buffer, getNbLinkedWith between
// the dependent and the stack from position 2, and one more when `label`
// differs from the relation read with config.getConst.
void Transition::initStandardRight_rel(std::string label)
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 1, Config::Object::Stack, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
  sequence.emplace_back(Action::popStack(0));

  costDynamic = [label](const Config & config)
  {
    auto governor = config.getStack(1);
    auto dependent = config.getStack(0);
    auto expectedHead = config.getConst(Config::headColName, dependent, 0);

    const bool correct = expectedHead == std::to_string(governor) and label == config.getConst(Config::deprelColName, dependent, 0);
    if (correct)
      return 0;

    int cost = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, dependent, config);
    cost += getNbLinkedWith(2, config.getStackSize()-1, Config::Object::Stack, dependent, config);
    if (label != config.getConst(Config::deprelColName, dependent, 0))
      cost += 1;

    return cost;
  };

  costStatic = costDynamic;
}
// eager_RIGHT: attaches the current buffer word to the top of the stack and
// pushes it on the stack, without writing any relation.
// Both costs are 0 when the head read with config.getConst for the dependent is
// the top of the stack; otherwise they add up getNbLinkedWith between the
// dependent and the stack from position 1 and getNbLinkedWithHead between the
// dependent and the following words of the sentence.
void Transition::initEagerRight()
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::pushWordIndexOnStack());
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

  costDynamic = [](const Config & config)
  {
    auto governor = config.getStack(0);
    auto dependent = config.getWordIndex();
    auto expectedHead = config.getConst(Config::headColName, dependent, 0);

    if (expectedHead == std::to_string(governor))
      return 0;

    int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, dependent, config);
    cost += getNbLinkedWithHead(dependent+1, getLastIndexOfSentence(dependent, config), Config::Object::Buffer, dependent, config);

    return cost;
  };

  costStatic = costDynamic;
}
// REDUCE_strict: pops the stack, after asserting that the head column of the
// top of the stack is not empty.
// Both costs are 0 when the top of the stack is not a token; otherwise they are
// getNbLinkedWithDeps between the top of the stack and the buffer words from
// the current one to the end of the sentence.
void Transition::initReduce_strict()
{
  sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
  sequence.emplace_back(Action::popStack(0));

  costDynamic = [](const Config & config)
  {
    auto top = config.getStack(0);
    auto current = config.getWordIndex();

    if (config.isToken(top))
      return getNbLinkedWithDeps(current, getLastIndexOfSentence(current, config), Config::Object::Buffer, top, config);

    return 0;
  };

  costStatic = costDynamic;
}
// gold_REDUCE_strict: same actions and costs as REDUCE_strict, plus a
// precondition that only accepts configurations whose cost would be 0.
void Transition::initGoldReduce_strict()
{
  sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
  sequence.emplace_back(Action::popStack(0));

  costDynamic = [](const Config & config)
  {
    auto top = config.getStack(0);
    auto current = config.getWordIndex();

    if (config.isToken(top))
      return getNbLinkedWithDeps(current, getLastIndexOfSentence(current, config), Config::Object::Buffer, top, config);

    return 0;
  };

  costStatic = costDynamic;

  precondition = [](const Config & config)
  {
    auto top = config.getStack(0);
    auto current = config.getWordIndex();

    if (config.isToken(top))
      return getNbLinkedWithDeps(current, getLastIndexOfSentence(current, config), Config::Object::Buffer, top, config) == 0;

    return true;
  };
}
// REDUCE_relaxed: pops the stack without any assertion.
// Both costs are 0 when the top of the stack is not a token; otherwise they are
// getNbLinkedWith between the top of the stack and the buffer words from the
// current one to the end of the sentence.
void Transition::initReduce_relaxed()
{
  sequence.emplace_back(Action::popStack(0));

  costDynamic = [](const Config & config)
  {
    auto top = config.getStack(0);
    auto current = config.getWordIndex();

    if (config.isToken(top))
      return getNbLinkedWith(current, getLastIndexOfSentence(current, config), Config::Object::Buffer, top, config);

    return 0;
  };

  costStatic = costDynamic;
}
// EOS b.<bufferIndex>: sets the root, updates the ids, writes
// Config::EOSSymbol1 in the EOS column of the designated buffer line and
// empties the stack.
// Both costs are 0 when the EOS column read with config.getConst for that line
// holds Config::EOSSymbol1, the maximal int otherwise.
void Transition::initEOS(int bufferIndex)
{
  sequence.emplace_back(Action::setRoot(bufferIndex));
  sequence.emplace_back(Action::updateIds(bufferIndex));
  sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1));
  sequence.emplace_back(Action::emptyStack());

  costDynamic = [bufferIndex](const Config & config)
  {
    int line = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex);

    if (config.getConst(Config::EOSColName, line, 0) != Config::EOSSymbol1)
    {
      return std::numeric_limits<int>::max();
    }

    return 0;
  };

  costStatic = costDynamic;
}
// NOTEOS b.<bufferIndex>: no action is appended.
// Both costs are the maximal int when the EOS column read with config.getConst
// for the designated buffer line holds Config::EOSSymbol1, 0 otherwise.
void Transition::initNotEOS(int bufferIndex)
{
  costDynamic = [bufferIndex](const Config & config)
  {
    int line = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex);

    if (config.getConst(Config::EOSColName, line, 0) == Config::EOSSymbol1)
    {
      return std::numeric_limits<int>::max();
    }

    return 0;
  };

  costStatic = costDynamic;
}
// deprel label: appends Action::deprel(label).
// Both costs are 0 when the relation read with config.getConst for the last
// attached word equals `label`, 1 otherwise.
void Transition::initDeprel(std::string label)
{
  sequence.emplace_back(Action::deprel(label));

  costDynamic = [label](const Config & config)
  {
    if (config.getConst(Config::deprelColName, config.getLastAttached(), 0) == label)
      return 0;

    return 1;
  };

  costStatic = costDynamic;
}
// TRANSFORMSUFFIX: appends an action that rewrites the suffix of a value.
//
// The source value is column `fromCol` of the line designated by
// (`fromObj`, `fromIndex`); the destination is column `toCol` of the line
// designated by (`toObj`, `toIndex`).
// `rule` is read as follows: its first character is skipped, the characters up
// to the first tab are the suffix to remove, and the characters after that tab
// (up to the next tab or the end) are the suffix to add.
//
// Both costs are 0 when the lowercased source value, with as many trailing
// elements removed as the suffix to remove has and the suffix to add appended,
// equals the lowercased destination value read with config.getConst; they are
// 1 otherwise, including when the source value is shorter than the suffix to
// remove.
void Transition::initTransformSuffix(std::string fromCol, std::string fromObj, std::string fromIndex, std::string toCol, std::string toObj, std::string toIndex, std::string rule)
{
  auto fromObjectValue = Config::str2object(fromObj);
  int fromIndexValue = std::stoi(fromIndex);
  auto toObjectValue = Config::str2object(toObj);
  int toIndexValue = std::stoi(toIndex);

  std::string toRemove, toAdd;
  util::utf8string toRemoveUtf8, toAddUtf8;

  // Position 0 of the rule is deliberately skipped.
  std::size_t index = 0;
  for (index = 1; index < rule.size() and rule[index] != '\t'; index++)
    toRemove.push_back(rule[index]);
  index++; // step over the tab separating the two suffixes
  for (; index < rule.size() and rule[index] != '\t'; index++)
    toAdd.push_back(rule[index]);

  toRemoveUtf8 = util::splitAsUtf8(toRemove);
  toAddUtf8 = util::splitAsUtf8(toAdd);

  sequence.emplace_back(Action::transformSuffix(fromCol, fromObjectValue, fromIndexValue, toCol, toObjectValue, toIndexValue, toRemoveUtf8, toAddUtf8));

  costDynamic = [fromObjectValue, fromIndexValue, toObjectValue, toIndexValue, toAddUtf8, toRemoveUtf8, fromCol, toCol](const Config & config)
  {
    int fromLineIndex = config.getRelativeWordIndex(fromObjectValue, fromIndexValue);
    int toLineIndex = config.getRelativeWordIndex(toObjectValue, toIndexValue);

    util::utf8string res = util::splitAsUtf8(util::lower(config.getAsFeature(fromCol, fromLineIndex).get()));

    // The rule cannot produce the expected value if there is less to remove
    // from than the rule requires; this also prevents popping from an empty
    // container below.
    if (res.size() < toRemoveUtf8.size())
      return 1;

    for (unsigned int i = 0; i < toRemoveUtf8.size(); i++)
      res.pop_back();
    for (auto & letter : toAddUtf8)
      res.push_back(letter);

    if (fmt::format("{}", res) == util::lower(config.getConst(toCol, toLineIndex, 0)))
      return 0;

    return 1;
  };

  costStatic = costDynamic;
}
// UPPERCASE: appends Action::uppercase on column `col` of the line designated
// by (`obj`, `index`).
// Both costs are 0 only when the current value differs from the value read
// with config.getConst and its uppercased version equals it; 1 otherwise.
void Transition::initUppercase(std::string col, std::string obj, std::string index)
{
  auto target = Config::str2object(obj);
  int offset = std::stoi(index);

  sequence.emplace_back(Action::uppercase(col, target, offset));

  costDynamic = [col, target, offset](const Config & config)
  {
    int line = config.getRelativeWordIndex(target, offset);
    auto & wanted = config.getConst(col, line, 0);
    std::string present = config.getAsFeature(col, line).get();

    // Nothing to change: applying the transformation would be a mistake.
    if (wanted == present)
      return 1;

    return util::upper(present) == wanted ? 0 : 1;
  };

  costStatic = costDynamic;
}
// UPPERCASEINDEX: appends Action::uppercaseIndex, which targets the element at
// position `inIndex` of column `col` on the line designated by (`obj`, `index`).
//
// Both costs are 0 only when the current value differs from the value read
// with config.getConst and uppercasing its element at position `inIndex` makes
// it equal to that value. They are 1 otherwise, including when `inIndex` does
// not designate an element of the current value.
void Transition::initUppercaseIndex(std::string col, std::string obj, std::string index, std::string inIndex)
{
  auto objectValue = Config::str2object(obj);
  int indexValue = std::stoi(index);
  int inIndexValue = std::stoi(inIndex);

  sequence.emplace_back(Action::uppercaseIndex(col, objectValue, indexValue, inIndexValue));

  costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config)
  {
    int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
    auto & expectedValue = config.getConst(col, lineIndex, 0);
    std::string currentValue = config.getAsFeature(col, lineIndex).get();

    if (expectedValue == currentValue)
      return 1;

    auto currentValueUtf8 = util::splitAsUtf8(currentValue);

    // An index outside of the value cannot lead to the expected value, and
    // must not be used to subscript the container.
    if (inIndexValue < 0 or static_cast<std::size_t>(inIndexValue) >= currentValueUtf8.size())
      return 1;

    // NOTE(review): the result of util::upper is discarded, so this relies on
    // the overload modifying its argument in place - confirm.
    util::upper(currentValueUtf8[inIndexValue]);

    if (fmt::format("{}", currentValueUtf8) == expectedValue)
      return 0;

    return 1;
  };

  costStatic = costDynamic;
}
// NOTHING col obj.index: no action is appended.
// Both costs are 0 when the current value of column `col` on the designated
// line already equals the value read with config.getConst, 1 otherwise.
void Transition::initNothing(std::string col, std::string obj, std::string index)
{
  auto target = Config::str2object(obj);
  int offset = std::stoi(index);

  costDynamic = [col, target, offset](const Config & config)
  {
    int line = config.getRelativeWordIndex(target, offset);
    auto & wanted = config.getConst(col, line, 0);
    std::string present = config.getAsFeature(col, line).get();

    return wanted == present ? 0 : 1;
  };

  costStatic = costDynamic;
}
// LOWERCASE: appends Action::lowercase on column `col` of the line designated
// by (`obj`, `index`).
// Both costs are 0 only when the current value differs from the value read
// with config.getConst and its lowercased version equals it; 1 otherwise.
void Transition::initLowercase(std::string col, std::string obj, std::string index)
{
  auto target = Config::str2object(obj);
  int offset = std::stoi(index);

  sequence.emplace_back(Action::lowercase(col, target, offset));

  costDynamic = [col, target, offset](const Config & config)
  {
    int line = config.getRelativeWordIndex(target, offset);
    auto & wanted = config.getConst(col, line, 0);
    std::string present = config.getAsFeature(col, line).get();

    // Nothing to change: applying the transformation would be a mistake.
    if (wanted == present)
      return 1;

    return util::lower(present) == wanted ? 0 : 1;
  };

  costStatic = costDynamic;
}
// LOWERCASEINDEX: appends Action::lowercaseIndex, which targets the element at
// position `inIndex` of column `col` on the line designated by (`obj`, `index`).
//
// Both costs are 0 only when the current value differs from the value read
// with config.getConst and lowercasing its element at position `inIndex` makes
// it equal to that value. They are 1 otherwise, including when `inIndex` does
// not designate an element of the current value.
void Transition::initLowercaseIndex(std::string col, std::string obj, std::string index, std::string inIndex)
{
  auto objectValue = Config::str2object(obj);
  int indexValue = std::stoi(index);
  int inIndexValue = std::stoi(inIndex);

  sequence.emplace_back(Action::lowercaseIndex(col, objectValue, indexValue, inIndexValue));

  costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config)
  {
    int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
    auto & expectedValue = config.getConst(col, lineIndex, 0);
    std::string currentValue = config.getAsFeature(col, lineIndex).get();

    if (expectedValue == currentValue)
      return 1;

    auto currentValueUtf8 = util::splitAsUtf8(currentValue);

    // An index outside of the value cannot lead to the expected value, and
    // must not be used to subscript the container.
    if (inIndexValue < 0 or static_cast<std::size_t>(inIndexValue) >= currentValueUtf8.size())
      return 1;

    // NOTE(review): the result of util::lower is discarded, so this relies on
    // the overload modifying its argument in place - confirm.
    util::lower(currentValueUtf8[inIndexValue]);

    if (fmt::format("{}", currentValueUtf8) == expectedValue)
      return 0;

    return 1;
  };

  costStatic = costDynamic;
}
// Counts the dependency links, according to the head column read with
// config.getConst, between the word `withIndex` and the tokens of a range.
//
// The range goes from `firstIndex` to `lastIndex` (inclusive); its elements are
// stack positions when `object` is Config::Object::Stack and are used directly
// as word indexes otherwise. Non-token lines are skipped. A token counts once
// when it is the head of `withIndex` and once when `withIndex` is its head.
int Transition::getNbLinkedWith(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
{
  auto headOfWith = config.getConst(Config::headColName, withIndex, 0);
  const std::string withAsString = std::to_string(withIndex);

  int links = 0;
  for (int position = firstIndex; position <= lastIndex; ++position)
  {
    int word = position;
    if (object == Config::Object::Stack)
      word = config.getStack(position);

    if (config.isToken(word))
    {
      auto headOfWord = config.getConst(Config::headColName, word, 0);

      if (headOfWith == std::to_string(word))
        ++links;
      if (headOfWord == withAsString)
        ++links;
    }
  }

  return links;
}
// Counts the tokens of a range that are the head of the word `withIndex`,
// according to the head column read with config.getConst.
//
// The range goes from `firstIndex` to `lastIndex` (inclusive); its elements are
// stack positions when `object` is Config::Object::Stack and are used directly
// as word indexes otherwise. Non-token lines are skipped.
int Transition::getNbLinkedWithHead(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
{
  auto headOfWith = config.getConst(Config::headColName, withIndex, 0);

  int links = 0;
  for (int position = firstIndex; position <= lastIndex; ++position)
  {
    int word = position;
    if (object == Config::Object::Stack)
      word = config.getStack(position);

    if (config.isToken(word) and headOfWith == std::to_string(word))
      ++links;
  }

  return links;
}
// Counts the tokens of a range whose head is the word `withIndex`, according
// to the head column read with config.getConst.
//
// The range goes from `firstIndex` to `lastIndex` (inclusive); its elements are
// stack positions when `object` is Config::Object::Stack and are used directly
// as word indexes otherwise. Non-token lines are skipped.
int Transition::getNbLinkedWithDeps(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
{
  const std::string withAsString = std::to_string(withIndex);

  int links = 0;
  for (int position = firstIndex; position <= lastIndex; ++position)
  {
    int word = position;
    if (object == Config::Object::Stack)
      word = config.getStack(position);

    if (config.isToken(word))
    {
      auto headOfWord = config.getConst(Config::headColName, word, 0);
      if (headOfWord == withAsString)
        ++links;
    }
  }

  return links;
}
// Walks backwards from `baseIndex`, as long as config.has(0, i, 0) holds, and
// returns the index of the last token visited before reaching a token whose
// EOS column (read with config.getConst) holds Config::EOSSymbol1.
// Non-token lines are skipped; `baseIndex` is returned when no token qualifies.
int Transition::getFirstIndexOfSentence(int baseIndex, const Config & config)
{
  int result = baseIndex;

  for (int line = baseIndex; config.has(0, line, 0); --line)
  {
    if (config.isToken(line))
    {
      if (config.getConst(Config::EOSColName, line, 0) == Config::EOSSymbol1)
        break;

      result = line;
    }
  }

  return result;
}
// Walks forwards from `baseIndex`, as long as config.has(0, i, 0) holds, and
// returns the index of the first token whose EOS column (read with
// config.getConst) holds Config::EOSSymbol1, or of the last token visited when
// there is none. Non-token lines are skipped; `baseIndex` is returned when no
// token is visited.
int Transition::getLastIndexOfSentence(int baseIndex, const Config & config)
{
  int result = baseIndex;

  for (int line = baseIndex; config.has(0, line, 0); ++line)
  {
    if (config.isToken(line))
    {
      result = line;

      if (config.getConst(Config::EOSColName, line, 0) == Config::EOSSymbol1)
        break;
    }
  }

  return result;
}