Newer
Older
// Builds a transition from its textual name.
// The full name is first matched against "(<(.+)> )?(.+)": the optional text between
// '<' and '>' is stored in `state`, the remainder in `this->name`. `this->name` is then
// tried against each entry of `inits`, in order; the first entry whose regex matches
// calls its init* method, which fills `sequence` and/or `cost`.
// If no entry matches, std::invalid_argument("no match") is thrown and the handler at
// the end reports it through util::myThrow as "Invalid name '<name>' (<what>)".
// NOTE(review): this listing shows a `} catch (...)` with no visible `try`, and stray
// "Franck Dary" / "committed" lines that are not C++; lines appear to have been lost or
// injected when this view was extracted — confirm against the real file.
Transition::Transition(const std::string & name)
{
// Dispatch table: (regex over the transition name, initializer receiving the match).
std::vector<std::pair<std::regex, std::function<void(const std::smatch &)>>> inits
{
{std::regex("WRITE ([bs])\\.(.+) (.+) (.+)"),
[this](auto sm){(initWrite(sm[3], sm[1], sm[2], sm[4]));}},
{std::regex("ADD ([bs])\\.(.+) (.+) (.+)"),
[this](auto sm){(initAdd(sm[3], sm[1], sm[2], sm[4]));}},
{std::regex("eager_SHIFT"),
[this](auto){initEagerShift();}},
{std::regex("standard_SHIFT"),
[this](auto){initStandardShift();}},
{std::regex("REDUCE_strict"),
[this](auto){initReduce_strict();}},
{std::regex("REDUCE_relaxed"),
[this](auto){initReduce_relaxed();}},
// NOTE(review): the `_rel` variants precede plain eager_LEFT/eager_RIGHT; if
// util::doIfNameMatch searches rather than full-matches, this order matters.
{std::regex("eager_LEFT_rel (.+)"),
[this](auto sm){(initEagerLeft_rel(sm[1]));}},
{std::regex("eager_RIGHT_rel (.+)"),
[this](auto sm){(initEagerRight_rel(sm[1]));}},
{std::regex("standard_LEFT_rel (.+)"),
[this](auto sm){(initStandardLeft_rel(sm[1]));}},
{std::regex("standard_RIGHT_rel (.+)"),
[this](auto sm){(initStandardRight_rel(sm[1]));}},
{std::regex("eager_LEFT"),
[this](auto){(initEagerLeft());}},
{std::regex("eager_RIGHT"),
[this](auto){(initEagerRight());}},
{std::regex("deprel (.+)"),
[this](auto sm){(initDeprel(sm[1]));}},
// std::stoi throws on a non-numeric capture; that is handled by the catch at the end.
{std::regex("EOS b\\.(.+)"),
[this](auto sm){initEOS(std::stoi(sm[1]));}},
{std::regex("NOTHING"),
[this](auto){initNothing();}},
{std::regex("IGNORECHAR"),
[this](auto){initIgnoreChar();}},
{std::regex("ENDWORD"),
[this](auto){initEndWord();}},
{std::regex("ADDCHARTOWORD"),
[this](auto){initAddCharToWord();}},
{std::regex("SPLIT (.+)"),
[this](auto sm){(initSplit(std::stoi(sm.str(1))));}},
// SPLITWORD <consumed>@<w1>@<w2>...: group 1 is the consumed text, group 2 the
// '@'-prefixed tail, which is split on '@' and appended after group 1.
// NOTE(review): "(:?" is an optional literal ':' inside a capturing group, not the
// non-capturing "(?:"; matching still works because group 2 captures the whole tail.
// Whether util::split keeps the empty field before the leading '@' is not visible here.
{std::regex("SPLITWORD ([^@]+)(:?(:?@[^@]+)+)"),
[this](auto sm)
{
std::vector<std::string> splitRes{sm[1]};
auto splited = util::split(std::string(sm[2]), '@');
for (auto & s : splited)
splitRes.emplace_back(s);
initSplitWord(splitRes);
}},
};
// `name` here is the constructor parameter (it shadows the member); the member is
// always written as `this->name`. Split off the optional "<state> " prefix.
if (!util::doIfNameMatch(std::regex("(<(.+)> )?(.+)"), name, [this, name](auto sm)
{
this->state = sm[2];
this->name = sm[3];
}))
util::myThrow("doesn't match nameRegex");
// The first matching initializer wins; returning skips the "no match" throw below.
for (auto & it : inits)
if (util::doIfNameMatch(it.first, this->name, it.second))
return;
Franck Dary
committed
throw std::invalid_argument("no match");
} catch (std::exception & e) {util::myThrow(fmt::format("Invalid name '{}' ({})", this->name, e.what()));}
Franck Dary
committed
}
/// Applies every Action of `sequence` to `config`, in insertion order.
/// Each Action's apply callback receives the Action itself as its second argument.
void Transition::apply(Config & config)
{
  for (auto step = sequence.begin(); step != sequence.end(); ++step)
  {
    Action & current = *step;
    current.apply(config, current);
  }
}
bool Transition::appliable(const Config & config) const
{
try
{
if (!state.empty() && state != config.getState())
return false;
for (const Action & action : sequence)
if (!action.appliable(config, action))
return false;
} catch (std::exception & e)
{
util::myThrow(fmt::format("transition '{}' {}", name, e.what()));
}
return true;
}
// Returns the cost of this transition in `config` by calling the `cost` callback set by
// one of the init* methods. A std::exception is reported through util::myThrow with the
// transition name prepended.
// NOTE(review): this function's closing brace is missing in this listing (the next line
// starts getName) — confirm against the real file.
int Transition::getCost(const Config & config) const
{
try {return cost(config);}
catch (std::exception & e) { util::myThrow(fmt::format("transition '{}' {}", name, e.what()));}
// Only reached if util::myThrow returns instead of throwing.
return 0;
const std::string & Transition::getName() const
{
return name;
}
Franck Dary
committed
// WRITE <object>.<index> <colName> <value>: appends
// Action::addHypothesisRelative(colName, objectValue, indexValue, value), where
// `object` is "b" or "s" when reached through the constructor's WRITE regex and `index`
// is parsed with std::stoi (which throws on non-numeric input).
// Cost: 0 when the gold value of `colName` at the designated line equals `value`, else 1.
// NOTE(review): `objectValue` is used but never declared here (presumably converted from
// `object` on a line lost from this listing), the function's closing brace is missing,
// and "Franck Dary"/"committed" lines are scrape residue — confirm against the real file.
void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
{
Franck Dary
committed
int indexValue = std::stoi(index);
sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));
cost = [colName, objectValue, indexValue, value](const Config & config)
{
// Line targeted by the action: relative position `indexValue` in `objectValue`.
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
Franck Dary
committed
if (config.getConst(colName, lineIndex, 0) == value)
return 0;
Franck Dary
committed
return 1;
};
// ADD <object>.<index> <colName> <value>: appends
// Action::addToHypothesisRelative(colName, objectValue, indexValue, value); `index` is
// parsed with std::stoi (which throws on non-numeric input).
// Cost: 0 when `value` is one of the '|'-separated parts of the gold `colName` cell of
// the designated line, 1 otherwise.
// NOTE(review): `objectValue` is used but not declared in this listing and `object` is
// otherwise unused — a conversion line appears to be missing. Confirm against the real file.
void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
{
int indexValue = std::stoi(index);
sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));
cost = [colName, objectValue, indexValue, value](const Config & config)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
// The gold cell may hold several values separated by '|'; any one of them matches.
auto gold = util::split(config.getConst(colName, lineIndex, 0).get(), '|');
for (auto & part : gold)
if (part == value)
return 0;
return 1;
};
}
/// NOTHING: adds no action to `sequence`; the transition always costs 0.
void Transition::initNothing()
{
  auto alwaysFree = [](const Config &) { return 0; };
  cost = alwaysFree;
}
/// IGNORECHAR: a single Action::ignoreCurrentCharacter(); the transition always costs 0.
void Transition::initIgnoreChar()
{
  auto alwaysFree = [](const Config &) { return 0; };
  sequence.emplace_back(Action::ignoreCurrentCharacter());
  cost = alwaysFree;
}
/// ENDWORD: appends Action::endWord().
/// Cost: 0 when the gold FORM of the current word equals its FORM as built so far
/// (config.getAsFeature), 1 otherwise.
void Transition::initEndWord()
{
  sequence.emplace_back(Action::endWord());
  cost = [](const Config & config)
  {
    const auto wordIndex = config.getWordIndex();
    return config.getConst("FORM", wordIndex, 0) == config.getAsFeature("FORM", wordIndex) ? 0 : 1;
  };
}
/// ADDCHARTOWORD: asserts that the ID column of b.0 is empty, calls
/// Action::addLinesIfNeeded(0), appends the current character to the current word and
/// moves the character index forward by 1.
/// Cost: std::numeric_limits<int>::max() when there is no character at the current
/// index. Otherwise the character is rendered with fmt::format and compared byte by byte
/// with the gold FORM, starting at the position just after the current word's length:
/// 0 if all bytes match, 1 if one differs or the gold FORM is too short.
void Transition::initAddCharToWord()
{
  sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::addLinesIfNeeded(0));
  sequence.emplace_back(Action::addCurCharToCurWord());
  sequence.emplace_back(Action::moveCharacterIndex(1));

  cost = [](const Config & config)
  {
    const auto charIndex = config.getCharacterIndex();
    if (!config.hasCharacter(charIndex))
      return std::numeric_limits<int>::max();

    const auto wordIndex = config.getWordIndex();
    auto letter = fmt::format("{}", config.getLetter(charIndex));
    auto & goldWord = config.getConst("FORM", wordIndex, 0).get();
    auto & curWord = config.getAsFeature("FORM", wordIndex).get();

    const auto offset = curWord.size();
    if (offset + letter.size() > goldWord.size())
      return 1;

    unsigned int matched = 0;
    while (matched < letter.size() && goldWord[offset + matched] == letter[matched])
      ++matched;
    return matched == letter.size() ? 0 : 1;
  };
}
// SPLITWORD: `words[0]` is the text consumed from the input (split into UTF-8 characters
// for Action::consumeCharacterIndex). Every entry of `words` (including words[0]) is then
// written to the FORM hypothesis of buffer positions 0..words.size()-1, and
// Action::setMultiwordIds(words.size()-1) is appended.
// Assumes `words` is non-empty (words[0] is read unconditionally); the constructor's
// SPLITWORD entry always supplies at least one element.
// Cost: std::numeric_limits<int>::max() if the current word is not a gold multiword, or
// if its multiword size + 2 differs from words.size(); otherwise the number of positions
// i in [0, words.size()) whose gold FORM at wordIndex+i is missing or differs from words[i].
// NOTE(review): the "Franck Dary"/"committed" lines are scrape residue, not C++.
void Transition::initSplitWord(std::vector<std::string> words)
{
auto consumedWord = util::splitAsUtf8(words[0]);
sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0));
sequence.emplace_back(Action::assertIsEmpty("FORM", Config::Object::Buffer, 0));
sequence.emplace_back(Action::addLinesIfNeeded(words.size()));
sequence.emplace_back(Action::consumeCharacterIndex(consumedWord));
for (unsigned int i = 0; i < words.size(); i++)
sequence.emplace_back(Action::addHypothesisRelative("FORM", Config::Object::Buffer, i, words[i]));
Franck Dary
committed
sequence.emplace_back(Action::setMultiwordIds(words.size()-1));
cost = [words](const Config & config)
{
Franck Dary
committed
if (!config.isMultiword(config.getWordIndex()))
return std::numeric_limits<int>::max();
if (config.getMultiwordSize(config.getWordIndex())+2 != (int)words.size())
return std::numeric_limits<int>::max();
// Local `cost` shadows the member std::function of the same name.
int cost = 0;
for (unsigned int i = 0; i < words.size(); i++)
if (!config.has("FORM", config.getWordIndex()+i, 0) or config.getConst("FORM", config.getWordIndex()+i, 0) != words[i])
cost++;
return cost;
};
}
/// SPLIT <index>: appends Action::split(index).
/// Cost: delegated to the `index`-th entry of config.getAppliableSplitTransitions();
/// an index outside that list costs std::numeric_limits<int>::max().
void Transition::initSplit(int index)
{
  sequence.emplace_back(Action::split(index));
  cost = [index](const Config & config)
  {
    auto & candidates = config.getAppliableSplitTransitions();
    const bool inRange = index >= 0 and index < static_cast<int>(candidates.size());
    return inRange ? candidates[index]->getCost(config) : std::numeric_limits<int>::max();
  };
}
// NOTE(review): these lines have no enclosing function header, and neither the lambda nor
// the function is closed in this listing. They presumably belong to initEagerShift
// (called from the constructor but not defined in the visible text) — confirm.
// Actions: push the current word index onto the stack, then
// Action::setRootUpdateIdsEmptyStackIfSentChanged().
sequence.emplace_back(Action::pushWordIndexOnStack());
sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
// Cost: gold links between the current word and every stack position 0..size-1
// (see getNbLinkedWith).
cost = [](const Config & config)
{
return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);
/// standard_SHIFT: pushes the current word index onto the stack, then applies
/// Action::setRootUpdateIdsEmptyStackIfSentChanged(); the transition always costs 0.
void Transition::initStandardShift()
{
  sequence.emplace_back(Action::pushWordIndexOnStack());
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
  cost = [](const Config &) { return 0; };
}
// eager_LEFT_rel <label>: Action::attach(Buffer 0, Stack 0) — the cost below treats b.0
// as governor and s.0 as dependent — then writes `label` into s.0's deprel hypothesis.
// Cost: 0 when b.0 is the gold head of s.0 with gold label `label`. Otherwise, the gold
// links between s.0 and the buffer words after b.0 up to the end of b.0's sentence
// (see getNbLinkedWith), plus 1 if s.0's gold label differs from `label`.
// NOTE(review): the opening brace of the function body is missing in this listing —
// confirm against the real file.
void Transition::initEagerLeft_rel(std::string label)
sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
cost = [label](const Config & config)
{
auto depIndex = config.getStack(0);
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
auto govIndex = config.getWordIndex();
if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0))
return 0;
// b.0 itself is excluded: the scan starts at govIndex+1.
int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config);
if (label != config.getConst(Config::deprelColName, depIndex, 0))
++cost;
return cost;
};
}
/// standard_LEFT_rel <label>: Action::attach(Stack 0, Stack 1) — the cost treats s.0 as
/// governor and s.1 as dependent — writes `label` into s.1's deprel hypothesis, then
/// applies Action::popStack(1).
/// Cost: 0 when s.0 is the gold head of s.1 with gold label `label`. Otherwise, the gold
/// links between s.1 and (a) the buffer from the current word to the end of its sentence
/// and (b) stack positions 1..size-1 (see getNbLinkedWith), plus 1 if s.1's gold label
/// differs from `label`.
void Transition::initStandardLeft_rel(std::string label)
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Stack, 1));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 1, label));
  sequence.emplace_back(Action::popStack(1));

  cost = [label](const Config & config)
  {
    auto govIndex = config.getStack(0);
    auto depIndex = config.getStack(1);
    auto goldGov = config.getConst(Config::headColName, depIndex, 0);

    if (goldGov == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0))
      return 0;

    auto wordIndex = config.getWordIndex();
    int lost = getNbLinkedWith(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, depIndex, config)
             + getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config);
    return label != config.getConst(Config::deprelColName, depIndex, 0) ? lost + 1 : lost;
  };
}
/// eager_LEFT: Action::attach(Buffer 0, Stack 0) with no label; the cost treats b.0 as
/// governor and s.0 as dependent.
/// Cost: 0 when b.0 is the gold head of s.0. Otherwise, the gold links between s.0 and
/// the buffer words after b.0 up to the end of b.0's sentence (see getNbLinkedWith).
void Transition::initEagerLeft()
{
  sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));

  cost = [](const Config & config)
  {
    auto govIndex = config.getWordIndex();
    auto depIndex = config.getStack(0);

    if (config.getConst(Config::headColName, depIndex, 0) == std::to_string(govIndex))
      return 0;

    // b.0 itself is excluded: the scan starts right after it.
    // NOTE(review): a gold arc from s.0 to b.0 is therefore not counted; confirm intended.
    return getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config);
  };
}
// eager_RIGHT_rel <label>: Action::attach(Stack 0, Buffer 0) — the cost treats s.0 as
// governor and b.0 as dependent — writes `label` into b.0's deprel hypothesis, pushes the
// current word index onto the stack, then applies
// Action::setRootUpdateIdsEmptyStackIfSentChanged().
// Cost: 0 when s.0 is the gold head of b.0 with gold label `label`. Otherwise, the gold
// links between b.0 and stack positions 1..size-1 (see getNbLinkedWith), plus a gold head
// of b.0 located later in its sentence (see getNbLinkedWithHead), plus 1 if b.0's gold
// label differs from `label`.
// NOTE(review): the opening brace of the function body is missing in this listing —
// confirm against the real file.
void Transition::initEagerRight_rel(std::string label)
sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
sequence.emplace_back(Action::pushWordIndexOnStack());
sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
cost = [label](const Config & config)
{
auto govIndex = config.getStack(0);
auto depIndex = config.getWordIndex();
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0))
return 0;
int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config);
cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config);
if (label != config.getConst(Config::deprelColName, depIndex, 0))
++cost;
return cost;
};
}
/// standard_RIGHT_rel <label>: Action::attach(Stack 1, Stack 0) — the cost treats s.1 as
/// governor and s.0 as dependent — writes `label` into s.0's deprel hypothesis, then
/// applies Action::popStack(0).
/// Cost: 0 when s.1 is the gold head of s.0 with gold label `label`. Otherwise, the gold
/// links between s.0 and (a) the buffer from the current word to the end of its sentence
/// and (b) stack positions 2..size-1 (see getNbLinkedWith), plus 1 if s.0's gold label
/// differs from `label`.
void Transition::initStandardRight_rel(std::string label)
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 1, Config::Object::Stack, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
  sequence.emplace_back(Action::popStack(0));

  cost = [label](const Config & config)
  {
    auto depIndex = config.getStack(0);
    auto govIndex = config.getStack(1);
    auto goldGov = config.getConst(Config::headColName, depIndex, 0);

    if (goldGov == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0))
      return 0;

    auto wordIndex = config.getWordIndex();
    int lost = getNbLinkedWith(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, depIndex, config)
             + getNbLinkedWith(2, config.getStackSize()-1, Config::Object::Stack, depIndex, config);
    return label != config.getConst(Config::deprelColName, depIndex, 0) ? lost + 1 : lost;
  };
}
/// eager_RIGHT: Action::attach(Stack 0, Buffer 0) with no label (the cost treats s.0 as
/// governor and b.0 as dependent), then pushes the current word index onto the stack and
/// applies Action::setRootUpdateIdsEmptyStackIfSentChanged().
/// Cost: 0 when s.0 is the gold head of b.0. Otherwise, the gold links between b.0 and
/// stack positions 1..size-1, plus a gold head of b.0 located later in its sentence.
void Transition::initEagerRight()
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::pushWordIndexOnStack());
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

  cost = [](const Config & config)
  {
    auto depIndex = config.getWordIndex();
    auto govIndex = config.getStack(0);

    if (config.getConst(Config::headColName, depIndex, 0) == std::to_string(govIndex))
      return 0;

    return getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config)
         + getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config);
  };
}
// REDUCE_strict: asserts that s.0's HEAD hypothesis is not empty.
// Cost: the number of buffer words (from `wordIndex` to the end of its sentence) whose
// gold head is s.0 and that have no predicted head yet (see getNbLinkedWithDeps).
// NOTE(review): in this listing the function's opening brace is missing, `wordIndex` is
// used in the lambda without being declared, and no pop action is visible — lines appear
// to be missing. Confirm against the real file.
void Transition::initReduce_strict()
sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
cost = [](const Config & config)
{
auto stackIndex = config.getStack(0);
int cost = getNbLinkedWithDeps(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
return cost;
};
}
// REDUCE_relaxed: cost counts the gold links between s.0 and the buffer from `wordIndex`
// to the end of its sentence (see getNbLinkedWith).
// NOTE(review): this span is garbled in the listing. initReduce_relaxed's cost lambda is
// cut off after computing `cost`, and the lines that follow (setRoot / updateIds / EOS
// hypothesis on `bufferIndex`, then a gold-EOS check) look like the body of
// initEOS(int), whose header is not visible. `wordIndex` and `bufferIndex` are not
// declared anywhere in the visible text — confirm against the real file.
void Transition::initReduce_relaxed()
{
cost = [](const Config & config)
{
auto stackIndex = config.getStack(0);
int cost = getNbLinkedWith(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
// Presumably initEOS: mark the word at relative buffer position `bufferIndex` as
// end of sentence (EOSSymbol1).
sequence.emplace_back(Action::setRoot(bufferIndex));
sequence.emplace_back(Action::updateIds(bufferIndex));
sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1));
// NOTE(review): uses `Config::Buffer` where the rest of the file writes
// `Config::Object::Buffer`.
int lineIndex = config.getRelativeWordIndex(Config::Buffer, bufferIndex);
// A word that is not a gold sentence end makes this transition maximally costly.
if (config.getConst(Config::EOSColName, lineIndex, 0) != Config::EOSSymbol1)
return std::numeric_limits<int>::max();
/// deprel <label>: appends Action::deprel(label).
/// Cost: 0 when the gold deprel of the last attached word equals `label`, 1 otherwise.
void Transition::initDeprel(std::string label)
{
  sequence.emplace_back(Action::deprel(label));
  cost = [label](const Config & config)
  {
    auto lastAttached = config.getLastAttached();
    if (config.getConst(Config::deprelColName, lastAttached, 0) == label)
      return 0;
    return 1;
  };
}
/// Counts gold links between the word `withIndex` and the words at positions
/// [firstIndex, lastIndex] of `object`. A Stack position is mapped to a word index with
/// config.getStack; any other object's position is used directly as a word index.
/// Non-token lines are skipped. Each candidate word contributes:
///  +1 if it is the gold head of `withIndex` and `withIndex`'s predicted head cell is empty;
///  +1 if `withIndex` is its gold head and its own predicted head cell is empty.
/// An empty range (firstIndex > lastIndex) yields 0.
int Transition::getNbLinkedWith(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
{
  auto goldHeadOfWith = config.getConst(Config::headColName, withIndex, 0);
  auto predictedHeadOfWith = config.getAsFeature(Config::headColName, withIndex);
  // Loop-invariant textual form of `withIndex`, compared against each gold head.
  const std::string withIndexText = std::to_string(withIndex);

  int links = 0;
  for (int pos = firstIndex; pos <= lastIndex; ++pos)
  {
    int wordIndex = pos;
    if (object == Config::Object::Stack)
      wordIndex = config.getStack(pos);

    if (!config.isToken(wordIndex))
      continue;

    auto goldHeadOfWord = config.getConst(Config::headColName, wordIndex, 0);
    auto predictedHeadOfWord = config.getAsFeature(Config::headColName, wordIndex);

    if (goldHeadOfWith == std::to_string(wordIndex) and util::isEmpty(predictedHeadOfWith))
      links++;
    if (goldHeadOfWord == withIndexText and util::isEmpty(predictedHeadOfWord))
      links++;
  }
  return links;
}
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
/// Counts the token words at positions [firstIndex, lastIndex] of `object` that are the
/// gold head of `withIndex`, provided `withIndex`'s predicted head cell is empty.
/// A Stack position is mapped to a word index with config.getStack; any other object's
/// position is used directly. An empty range yields 0.
int Transition::getNbLinkedWithHead(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
{
  auto goldHead = config.getConst(Config::headColName, withIndex, 0);
  auto predictedHead = config.getAsFeature(Config::headColName, withIndex);

  int total = 0;
  for (int pos = firstIndex; pos <= lastIndex; ++pos)
  {
    int lineIndex = pos;
    if (object == Config::Object::Stack)
      lineIndex = config.getStack(pos);

    if (config.isToken(lineIndex) and goldHead == std::to_string(lineIndex) and util::isEmpty(predictedHead))
      ++total;
  }
  return total;
}
int Transition::getNbLinkedWithDeps(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
{
int nbLinkedWith = 0;
for (int i = firstIndex; i <= lastIndex; ++i)
{
int index = i;
if (object == Config::Object::Stack)
index = config.getStack(i);
if (!config.isToken(index))
continue;
auto otherGovIndex = config.getConst(Config::headColName, index, 0);
auto otherGovIndexPredicted = config.getAsFeature(Config::headColName, index);
if (otherGovIndex == std::to_string(withIndex) and util::isEmpty(otherGovIndexPredicted))
++nbLinkedWith;
}
return nbLinkedWith;
}
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
/// Walks backwards from `baseIndex` while config.has(0, i, 0) holds, skipping non-token
/// lines. Returns the smallest token index visited before reaching a token whose gold EOS
/// column equals Config::EOSSymbol1; that EOS token is not included. Returns `baseIndex`
/// when no earlier token qualifies, including when `baseIndex` is itself such an EOS token.
int Transition::getFirstIndexOfSentence(int baseIndex, const Config & config)
{
  int firstIndex = baseIndex;
  int cursor = baseIndex;
  while (config.has(0, cursor, 0))
  {
    if (config.isToken(cursor))
    {
      if (config.getConst(Config::EOSColName, cursor, 0) == Config::EOSSymbol1)
        break;
      firstIndex = cursor;
    }
    --cursor;
  }
  return firstIndex;
}
/// Walks forward from `baseIndex` while config.has(0, i, 0) holds, skipping non-token
/// lines. Returns the last token index visited, stopping at (and including) the first
/// token whose gold EOS column equals Config::EOSSymbol1. Returns `baseIndex` when no
/// token is visited.
int Transition::getLastIndexOfSentence(int baseIndex, const Config & config)
{
  int lastIndex = baseIndex;
  bool reachedEOS = false;
  for (int cursor = baseIndex; !reachedEOS && config.has(0, cursor, 0); ++cursor)
  {
    if (!config.isToken(cursor))
      continue;
    lastIndex = cursor;
    if (config.getConst(Config::EOSColName, cursor, 0) == Config::EOSSymbol1)
      reachedEOS = true;
  }
  return lastIndex;
}