Newer
Older
/// Builds a transition from its textual name.
///
/// The name may carry an optional "<state> " prefix restricting the
/// transition to one machine state; the remainder is matched against a table
/// of patterns, and the first matching pattern's initializer fills
/// `sequence` (the actions to run) and `cost` (the oracle cost function).
///
/// @param name  full transition name, e.g. "<tagger> WRITE b.0 UPOS NOUN".
/// @throws whatever util::myThrow throws, with a message naming the invalid
///         input, when the name matches no known pattern.
Transition::Transition(const std::string & name)
{
  // Pattern -> initializer table. Entries are tried in order and the first
  // match wins, so the "_rel" variants are listed before their plain forms.
  std::vector<std::pair<std::regex, std::function<void(const std::smatch &)>>> inits
  {
    {std::regex("WRITE ([bs])\\.(.+) (.+) (.+)"),
      [this](const auto & sm){initWrite(sm[3], sm[1], sm[2], sm[4]);}},
    {std::regex("ADD ([bs])\\.(.+) (.+) (.+)"),
      [this](const auto & sm){initAdd(sm[3], sm[1], sm[2], sm[4]);}},
    {std::regex("SHIFT"),
      [this](const auto &){initShift();}},
    {std::regex("REDUCE_strict"),
      [this](const auto &){initReduce_strict();}},
    {std::regex("REDUCE_relaxed"),
      [this](const auto &){initReduce_relaxed();}},
    {std::regex("eager_LEFT_rel (.+)"),
      [this](const auto & sm){initEagerLeft_rel(sm[1]);}},
    {std::regex("eager_RIGHT_rel (.+)"),
      [this](const auto & sm){initEagerRight_rel(sm[1]);}},
    {std::regex("eager_LEFT"),
      [this](const auto &){initEagerLeft();}},
    {std::regex("eager_RIGHT"),
      [this](const auto &){initEagerRight();}},
    {std::regex("deprel (.+)"),
      [this](const auto & sm){initDeprel(sm[1]);}},
    {std::regex("EOS b\\.(.+)"),
      [this](const auto & sm){initEOS(std::stoi(sm[1]));}},
    {std::regex("NOTHING"),
      [this](const auto &){initNothing();}},
    {std::regex("IGNORECHAR"),
      [this](const auto &){initIgnoreChar();}},
    {std::regex("ENDWORD"),
      [this](const auto &){initEndWord();}},
    {std::regex("ADDCHARTOWORD"),
      [this](const auto &){initAddCharToWord();}},
    {std::regex("SPLIT (.+)"),
      [this](const auto & sm){initSplit(std::stoi(sm.str(1)));}},
    // NOTE: "(:?" below is an optional ':' (not a non-capturing group);
    // capture group 2 is the "@part@part..." tail that gets split on '@'.
    {std::regex("SPLITWORD ([^@]+)(:?(:?@[^@]+)+)"),
      [this](const auto & sm)
      {
        std::vector<std::string> splitRes{sm[1]};
        auto splited = util::split(std::string(sm[2]), '@');
        for (auto & s : splited)
          splitRes.emplace_back(s);
        initSplitWord(splitRes);
      }},
  };

  try
  {
    // Separate the optional "<state> " prefix from the transition name.
    if (!util::doIfNameMatch(std::regex("(<(.+)> )?(.+)"), name, [this](const auto & sm)
          {
            this->state = sm[2];
            this->name = sm[3];
          }))
      util::myThrow("doesn't match nameRegex");

    for (auto & it : inits)
      if (util::doIfNameMatch(it.first, this->name, it.second))
        return;

    throw std::invalid_argument("no match");
  } catch (std::exception & e) {util::myThrow(fmt::format("Invalid name '{}' ({})", name, e.what()));}
  // Report the caller-supplied `name` (the parameter): `this->name` is still
  // empty when the name regex itself fails, and lacks the state prefix.
}
/// Executes this transition: each action of `sequence` is applied to
/// `config`, in order, receiving itself as its second argument.
void Transition::apply(Config & config)
{
  for (auto & step : sequence)
    step.apply(config, step);
}
bool Transition::appliable(const Config & config) const
{
if (!state.empty() && state != config.getState())
return false;
for (const Action & action : sequence)
if (!action.appliable(config, action))
return false;
return true;
}
/// Evaluates the oracle cost of this transition on `config`, using the cost
/// function installed by the init* method selected in the constructor.
int Transition::getCost(const Config & config) const
{
  const int result = cost(config);
  return result;
}
const std::string & Transition::getName() const
{
return name;
}
Franck Dary
committed
void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
{
Franck Dary
committed
int indexValue = std::stoi(index);
sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));
cost = [colName, objectValue, indexValue, value](const Config & config)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
Franck Dary
committed
if (config.getConst(colName, lineIndex, 0) == value)
return 0;
Franck Dary
committed
return 1;
};
void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
{
int indexValue = std::stoi(index);
sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));
cost = [colName, objectValue, indexValue, value](const Config & config)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto gold = util::split(config.getConst(colName, lineIndex, 0).get(), '|');
for (auto & part : gold)
if (part == value)
return 0;
return 1;
};
}
/// NOTHING: adds no action; its cost is always 0.
void Transition::initNothing()
{
  cost = [](const Config &) { return 0; };
}
/// IGNORECHAR: a single Action::ignoreCurrentCharacter step; always costs 0.
void Transition::initIgnoreChar()
{
  sequence.emplace_back(Action::ignoreCurrentCharacter());
  cost = [](const Config &) { return 0; };
}
/// ENDWORD: a single Action::endWord step.
/// Costs 0 when the gold FORM of the current word equals the FORM built so
/// far (getAsFeature), 1 otherwise.
void Transition::initEndWord()
{
  sequence.emplace_back(Action::endWord());
  cost = [](const Config & config)
  {
    auto wordIndex = config.getWordIndex();
    const bool formComplete = config.getConst("FORM", wordIndex, 0) == config.getAsFeature("FORM", wordIndex);
    return formComplete ? 0 : 1;
  };
}
/// ADDCHARTOWORD: appends the current character to the current word and
/// advances the character index by one. It first asserts that the ID column
/// of buffer position 0 is empty and adds lines if needed.
///
/// Cost: INT_MAX when there is no current character. Otherwise 0 if the
/// character's text continues the gold FORM exactly where the built FORM
/// ends, and 1 if it would overflow or diverge from the gold FORM.
void Transition::initAddCharToWord()
{
  sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::addLinesIfNeeded(0));
  sequence.emplace_back(Action::addCurCharToCurWord());
  sequence.emplace_back(Action::moveCharacterIndex(1));

  cost = [](const Config & config)
  {
    auto charIndex = config.getCharacterIndex();
    if (!config.hasCharacter(charIndex))
      return std::numeric_limits<int>::max();

    // The letter is formatted to text; it may span several bytes.
    auto letter = fmt::format("{}", config.getLetter(charIndex));
    auto wordIndex = config.getWordIndex();
    auto & goldWord = config.getConst("FORM", wordIndex, 0).get();
    auto & curWord = config.getAsFeature("FORM", wordIndex).get();

    auto offset = curWord.size();
    if (offset + letter.size() > goldWord.size())
      return 1;

    for (unsigned int k = 0; k < letter.size(); ++k)
    {
      if (goldWord[offset + k] != letter[k])
        return 1;
    }

    return 0;
  };
}
/// SPLITWORD: consumes the characters of words[0] from the input and writes
/// every entry of `words` as FORM into consecutive buffer positions
/// (0 .. words.size()-1). It then sets multiword IDs covering words.size()-1
/// positions.
/// Preconditions (asserted by actions): buffer position 0 has an empty ID
/// and an empty FORM.
///
/// Cost: INT_MAX if the current word is not a gold multiword, or if its
/// multiword size plus 2 differs from words.size(). Otherwise it is the
/// number of positions whose gold FORM is missing or differs from `words`.
///
/// @param words  surface word first, followed by its parts; must be non-empty.
void Transition::initSplitWord(std::vector<std::string> words)
{
  // words[0] and words.size()-1 below would be UB / wrap around on an empty vector.
  if (words.empty())
    util::myThrow("SPLITWORD needs at least one word");

  auto consumedWord = util::splitAsUtf8(words[0]);
  sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::assertIsEmpty("FORM", Config::Object::Buffer, 0));
  sequence.emplace_back(Action::addLinesIfNeeded(words.size()));
  sequence.emplace_back(Action::consumeCharacterIndex(consumedWord));
  for (unsigned int i = 0; i < words.size(); i++)
    sequence.emplace_back(Action::addHypothesisRelative("FORM", Config::Object::Buffer, i, words[i]));
  sequence.emplace_back(Action::setMultiwordIds(words.size()-1));

  cost = [words](const Config & config)
  {
    if (!config.isMultiword(config.getWordIndex()))
      return std::numeric_limits<int>::max();

    if (config.getMultiwordSize(config.getWordIndex())+2 != (int)words.size())
      return std::numeric_limits<int>::max();

    int cost = 0;
    for (unsigned int i = 0; i < words.size(); i++)
      if (!config.has("FORM", config.getWordIndex()+i, 0) or config.getConst("FORM", config.getWordIndex()+i, 0) != words[i])
        cost++;

    return cost;
  };
}
/// SPLIT <index>: a single Action::split(index) step.
/// Its cost is delegated to the index-th entry of the configuration's
/// appliable split transitions, or is INT_MAX when `index` is out of range.
void Transition::initSplit(int index)
{
  sequence.emplace_back(Action::split(index));
  cost = [index](const Config & config)
  {
    auto & candidates = config.getAppliableSplitTransitions();
    const bool outOfRange = index < 0 || index >= static_cast<int>(candidates.size());
    return outOfRange ? std::numeric_limits<int>::max() : candidates[index]->getCost(config);
  };
}
sequence.emplace_back(Action::pushWordIndexOnStack());
sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
cost = [](const Config & config)
{
if (config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
return std::numeric_limits<int>::max();
return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);
void Transition::initEagerLeft_rel(std::string label)
sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
cost = [label](const Config & config)
{
auto stackIndex = config.getStack(0);
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
auto wordIndex = config.getWordIndex();
if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
return std::numeric_limits<int>::max();
int cost = getNbLinkedWith(wordIndex+1, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
if (stackGovIndex != std::to_string(wordIndex))
++cost;
if (label != config.getConst(Config::deprelColName, stackIndex, 0))
++cost;
return cost;
};
}
/// eager_LEFT: links buffer position 0 and stack position 0 via
/// Action::attach(Buffer, 0, Stack, 0).
///
/// Cost: INT_MAX unless both the stack top and the current word are tokens.
/// Otherwise it is the number of still-unattached gold links between the
/// stack top and the words after the current one up to the sentence end,
/// plus 1 if the stack top's gold head is not the current word.
void Transition::initEagerLeft()
{
  sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
  cost = [](const Config & config)
  {
    auto stackTop = config.getStack(0);
    auto goldHeadOfStackTop = config.getConst(Config::headColName, stackTop, 0);
    auto bufferHead = config.getWordIndex();
    if (!config.isToken(stackTop) || !config.isToken(bufferHead))
      return std::numeric_limits<int>::max();

    int penalty = getNbLinkedWith(bufferHead+1, getLastIndexOfSentence(bufferHead, config), Config::Object::Buffer, stackTop, config);
    penalty += (goldHeadOfStackTop != std::to_string(bufferHead)) ? 1 : 0;
    return penalty;
  };
}
void Transition::initEagerRight_rel(std::string label)
sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
sequence.emplace_back(Action::pushWordIndexOnStack());
sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
cost = [label](const Config & config)
{
auto stackIndex = config.getStack(0);
auto wordIndex = config.getWordIndex();
if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
return std::numeric_limits<int>::max();
auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);
if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
return std::numeric_limits<int>::max();
int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);
if (bufferGovIndex != std::to_string(stackIndex))
++cost;
if (label != config.getConst(Config::deprelColName, wordIndex, 0))
++cost;
return cost;
};
}
/// eager_RIGHT: links stack position 0 and buffer position 0 via
/// Action::attach(Stack, 0, Buffer, 0). It then pushes the current word on
/// the stack and runs Action::setRootUpdateIdsEmptyStackIfSentChanged.
///
/// Cost: INT_MAX unless both the stack top and the current word are tokens,
/// and INT_MAX when the stack top is marked EOSSymbol1. Otherwise it is the
/// number of still-unattached gold links between the current word and stack
/// positions 1 .. stackSize-1, plus 1 if the current word's gold head is not
/// the stack top.
void Transition::initEagerRight()
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::pushWordIndexOnStack());
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

  cost = [](const Config & config)
  {
    auto stackTop = config.getStack(0);
    auto bufferHead = config.getWordIndex();
    if (!config.isToken(stackTop) || !config.isToken(bufferHead))
      return std::numeric_limits<int>::max();

    auto goldHeadOfBufferHead = config.getConst(Config::headColName, bufferHead, 0);
    if (config.getConst(Config::EOSColName, stackTop, 0) == Config::EOSSymbol1)
      return std::numeric_limits<int>::max();

    int penalty = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, bufferHead, config);
    penalty += (goldHeadOfBufferHead != std::to_string(stackTop)) ? 1 : 0;
    return penalty;
  };
}
void Transition::initReduce_strict()
sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
cost = [](const Config & config)
{
auto stackIndex = config.getStack(0);
int cost = getNbLinkedWith(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
if (config.getConst(Config::EOSColName, stackIndex, 0) == Config::EOSSymbol1)
++cost;
return cost;
};
}
void Transition::initReduce_relaxed()
{
cost = [](const Config & config)
{
auto stackIndex = config.getStack(0);
int cost = getNbLinkedWith(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
if (config.getConst(Config::EOSColName, stackIndex, 0) == Config::EOSSymbol1)
sequence.emplace_back(Action::setRoot(bufferIndex));
sequence.emplace_back(Action::updateIds(bufferIndex));
sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1));
int lineIndex = config.getRelativeWordIndex(Config::Buffer, bufferIndex);
if (config.getConst(Config::EOSColName, lineIndex, 0) != Config::EOSSymbol1)
return std::numeric_limits<int>::max();
/// deprel <label>: a single Action::deprel(label) step.
/// Costs 0 when the gold deprel of the last attached word equals `label`,
/// 1 otherwise.
void Transition::initDeprel(std::string label)
{
  sequence.emplace_back(Action::deprel(label));
  cost = [label](const Config & config)
  {
    auto attached = config.getLastAttached();
    if (config.getConst(Config::deprelColName, attached, 0) == label)
      return 0;
    return 1;
  };
}
/// Counts gold dependency links between the word at `withIndex` and a range
/// of candidate words that have not been predicted yet.
///
/// Candidates are positions firstIndex..lastIndex (inclusive). These are
/// line indices when `object` is Buffer, and stack depths (mapped through
/// config.getStack) when it is Stack. Non-token candidates are skipped.
///
/// For each candidate, 1 is added when the gold head of `withIndex` is the
/// candidate and `withIndex` has no predicted head yet. Another 1 is added
/// when the candidate's gold head is `withIndex` and the candidate has no
/// predicted head yet.
///
/// @return the number of such links (0 for an empty range).
int Transition::getNbLinkedWith(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
{
  auto govIndex = config.getConst(Config::headColName, withIndex, 0);
  auto govIndexPredicted = config.getAsFeature(Config::headColName, withIndex);

  // Loop-invariant values, hoisted out of the scan.
  const bool withIndexUnattached = util::isEmpty(govIndexPredicted);
  const std::string withIndexStr = std::to_string(withIndex);

  int nbLinkedWith = 0;
  for (int i = firstIndex; i <= lastIndex; ++i)
  {
    int index = i;
    if (object == Config::Object::Stack)
      index = config.getStack(i);

    if (!config.isToken(index))
      continue;

    auto otherGovIndex = config.getConst(Config::headColName, index, 0);
    auto otherGovIndexPredicted = config.getAsFeature(Config::headColName, index);

    if (withIndexUnattached and govIndex == std::to_string(index))
      ++nbLinkedWith;
    if (otherGovIndex == withIndexStr and util::isEmpty(otherGovIndexPredicted))
      ++nbLinkedWith;
  }

  return nbLinkedWith;
}
/// Walks backward from `baseIndex` while config.has(0, i, 0) holds, which
/// presumably means line i exists (TODO confirm). Non-token lines are skipped.
/// The walk stops at the first token whose EOS column equals EOSSymbol1;
/// that token is excluded.
///
/// @return the smallest token index visited before stopping, or `baseIndex`
///         if none was recorded (e.g. when `baseIndex` itself ends a sentence).
int Transition::getFirstIndexOfSentence(int baseIndex, const Config & config)
{
  int first = baseIndex;
  int cursor = baseIndex;
  while (config.has(0, cursor, 0))
  {
    if (config.isToken(cursor))
    {
      if (config.getConst(Config::EOSColName, cursor, 0) == Config::EOSSymbol1)
        break;
      first = cursor;
    }
    --cursor;
  }
  return first;
}
/// Walks forward from `baseIndex` while config.has(0, i, 0) holds, which
/// presumably means line i exists (TODO confirm). Non-token lines are skipped.
///
/// @return the first token at or after `baseIndex` whose EOS column equals
///         EOSSymbol1 (inclusive). If no such token is found, the last token
///         index visited, or `baseIndex` when no token was visited.
int Transition::getLastIndexOfSentence(int baseIndex, const Config & config)
{
  int last = baseIndex;
  int cursor = baseIndex;
  while (config.has(0, cursor, 0))
  {
    if (config.isToken(cursor))
    {
      last = cursor;
      if (config.getConst(Config::EOSColName, cursor, 0) == Config::EOSSymbol1)
        break;
    }
    ++cursor;
  }
  return last;
}