Newer
Older
/// Builds a transition from its textual description.
///
/// `name` is first matched against "(<(.+)> )?(.+)": the text captured between
/// the angle brackets (possibly empty) is stored in `state`, and the remainder
/// is stored in `this->name`. Each known pattern is then handed, in order, to
/// util::doIfNameMatch together with its initializer; construction stops at the
/// first pattern for which doIfNameMatch reports success.
///
/// Any std::exception raised during this parsing (including the "no match"
/// std::invalid_argument thrown when no pattern applies) is reported through
/// util::myThrow as "Invalid name '<name>' (<reason>)".
Transition::Transition(const std::string & name)
{
  // Pattern -> initializer table. Order matters: entries are tried top to bottom.
  std::vector<std::pair<std::regex, std::function<void(const std::smatch &)>>> inits
  {
    {std::regex("WRITE ([bs])\\.(.+) (.+) (.+)"),
      [this](const auto & sm){initWrite(sm[3], sm[1], sm[2], sm[4]);}},
    {std::regex("ADD ([bs])\\.(.+) (.+) (.+)"),
      [this](const auto & sm){initAdd(sm[3], sm[1], sm[2], sm[4]);}},
    {std::regex("eager_SHIFT"),
      [this](const auto &){initEagerShift();}},
    {std::regex("standard_SHIFT"),
      [this](const auto &){initStandardShift();}},
    {std::regex("REDUCE_strict"),
      [this](const auto &){initReduce_strict();}},
    {std::regex("REDUCE_relaxed"),
      [this](const auto &){initReduce_relaxed();}},
    {std::regex("eager_LEFT_rel (.+)"),
      [this](const auto & sm){initEagerLeft_rel(sm[1]);}},
    {std::regex("eager_RIGHT_rel (.+)"),
      [this](const auto & sm){initEagerRight_rel(sm[1]);}},
    {std::regex("standard_LEFT_rel (.+)"),
      [this](const auto & sm){initStandardLeft_rel(sm[1]);}},
    {std::regex("standard_RIGHT_rel (.+)"),
      [this](const auto & sm){initStandardRight_rel(sm[1]);}},
    {std::regex("eager_LEFT"),
      [this](const auto &){initEagerLeft();}},
    {std::regex("eager_RIGHT"),
      [this](const auto &){initEagerRight();}},
    {std::regex("deprel (.+)"),
      [this](const auto & sm){initDeprel(sm[1]);}},
    {std::regex("EOS b\\.(.+)"),
      [this](const auto & sm){initEOS(std::stoi(sm[1]));}},
    {std::regex("NOTHING"),
      [this](const auto &){initNothing();}},
    {std::regex("IGNORECHAR"),
      [this](const auto &){initIgnoreChar();}},
    {std::regex("ENDWORD"),
      [this](const auto &){initEndWord();}},
    {std::regex("ADDCHARTOWORD"),
      [this](const auto &){initAddCharToWord();}},
    {std::regex("SPLIT (.+)"),
      [this](const auto & sm){initSplit(std::stoi(sm.str(1)));}},
    {std::regex("SPLITWORD ([^@]+)(:?(:?@[^@]+)+)"),
      [this](const auto & sm)
      {
        // First capture is the consumed word; second capture is the
        // '@'-separated list of the words it is split into.
        std::vector<std::string> splitRes{sm[1]};
        auto splited = util::split(std::string(sm[2]), '@');
        for (auto & s : splited)
          splitRes.emplace_back(s);
        initSplitWord(splitRes);
      }},
  };

  try
  {
    // Separate the optional "<state> " prefix from the transition name proper.
    if (!util::doIfNameMatch(std::regex("(<(.+)> )?(.+)"), name, [this](const auto & sm)
          {
            this->state = sm[2];
            this->name = sm[3];
          }))
      util::myThrow("doesn't match nameRegex");

    for (auto & it : inits)
      if (util::doIfNameMatch(it.first, this->name, it.second))
        return;

    throw std::invalid_argument("no match");
  }
  catch (const std::exception & e)
  {
    util::myThrow(fmt::format("Invalid name '{}' ({})", this->name, e.what()));
  }
}
/// Runs every action stored in `sequence` on `config`, in storage order.
/// Each action receives itself as the second argument of its own apply callable.
void Transition::apply(Config & config)
{
  for (auto step = sequence.begin(), last = sequence.end(); step != last; ++step)
    step->apply(config, *step);
}
bool Transition::appliable(const Config & config) const
{
if (!state.empty() && state != config.getState())
return false;
for (const Action & action : sequence)
if (!action.appliable(config, action))
return false;
return true;
}
/// Evaluates the cost callable installed by one of the init* methods on `config`
/// and returns its result.
int Transition::getCost(const Config & config) const
{
  const int result = cost(config);
  return result;
}
const std::string & Transition::getName() const
{
return name;
}
Franck Dary
committed
/// Sets up a WRITE transition.
///
/// Registers an addHypothesisRelative action writing `value` into column
/// `colName` at the position described by (objectValue, index), and installs a
/// cost of 0 when config.getConst(colName, lineIndex, 0) equals `value`,
/// 1 otherwise.
///
/// @param colName column to write into
/// @param object  textual designation of the object the index is relative to
/// @param index   relative index, parsed with std::stoi (throws on non-numeric text)
/// @param value   value to write
void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
{
  // NOTE(review): objectValue is used below but nothing visible declares it, and
  // `object` is never read — presumably a line converting `object` into the
  // Config object kind belongs here; confirm against the full file.
  int indexValue = std::stoi(index);

  sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));

  cost = [colName, objectValue, indexValue, value](const Config & config)
  {
    int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);

    if (config.getConst(colName, lineIndex, 0) == value)
      return 0;

    return 1;
  };
}
// Sets up an ADD transition: registers an addToHypothesisRelative action for
// column `colName`, and installs a cost of 0 when `value` equals one of the
// '|'-separated parts of config.getConst(colName, lineIndex, 0), 1 otherwise.
// `index` is parsed with std::stoi, which throws on non-numeric text.
void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
{
// NOTE(review): objectValue is used below but nothing visible declares it, and
// `object` is never read — presumably a conversion of `object` is missing here;
// confirm against the full file.
int indexValue = std::stoi(index);
sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));
cost = [colName, objectValue, indexValue, value](const Config & config)
{
// Resolve the relative position to a line index before reading the cell.
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto gold = util::split(config.getConst(colName, lineIndex, 0).get(), '|');
// Free as soon as any part matches the value being added.
for (auto & part : gold)
if (part == value)
return 0;
return 1;
};
}
/// Sets up the NOTHING transition: no action is registered and the cost is
/// always 0.
void Transition::initNothing()
{
  cost = [](const Config &) -> int { return 0; };
}
/// Sets up the IGNORECHAR transition: registers the ignoreCurrentCharacter
/// action and installs a cost that is always 0.
void Transition::initIgnoreChar()
{
  sequence.emplace_back(Action::ignoreCurrentCharacter());

  cost = [](const Config &) -> int { return 0; };
}
/// Sets up the ENDWORD transition: registers the endWord action and installs a
/// cost of 0 when the "FORM" value from getConst equals the "FORM" value from
/// getAsFeature at the current word index, 1 otherwise.
void Transition::initEndWord()
{
  sequence.emplace_back(Action::endWord());

  cost = [](const Config & config) -> int
  {
    return (config.getConst("FORM", config.getWordIndex(), 0) == config.getAsFeature("FORM", config.getWordIndex())) ? 0 : 1;
  };
}
/// Sets up the ADDCHARTOWORD transition.
///
/// Actions, in order: assert that the id column of buffer position 0 is empty,
/// add lines if needed, append the current character to the current word, and
/// advance the character index by one.
///
/// Cost: max int when there is no character at the current character index;
/// 1 when appending the letter would exceed the length of the getConst "FORM"
/// value or would not match it at that position; 0 otherwise.
void Transition::initAddCharToWord()
{
  sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::addLinesIfNeeded(0));
  sequence.emplace_back(Action::addCurCharToCurWord());
  sequence.emplace_back(Action::moveCharacterIndex(1));

  cost = [](const Config & config) -> int
  {
    if (!config.hasCharacter(config.getCharacterIndex()))
      return std::numeric_limits<int>::max();

    auto nextLetter = fmt::format("{}", config.getLetter(config.getCharacterIndex()));
    auto & expected = config.getConst("FORM", config.getWordIndex(), 0).get();
    auto & built = config.getAsFeature("FORM", config.getWordIndex()).get();

    // The letter must still fit inside the expected form...
    if (built.size() + nextLetter.size() > expected.size())
      return 1;

    // ...and must agree with it, element by element, right after what is already built.
    for (unsigned int k = 0; k < nextLetter.size(); k++)
    {
      if (expected[built.size()+k] != nextLetter[k])
        return 1;
    }

    return 0;
  };
}
/// Sets up a SPLITWORD transition.
///
/// @param words words[0] is the text consumed from the input (split with
///              util::splitAsUtf8); every entry of `words`, words[0] included,
///              is written as a "FORM" hypothesis at buffer offsets 0, 1, ...
///
/// Actions, in order: assert that the id column and the "FORM" column of buffer
/// position 0 are empty, add lines if needed for words.size(), consume the
/// characters of words[0], write each entry of `words`, then set the multiword
/// ids for words.size()-1.
///
/// Cost: max int when the current word is not a multiword or when
/// getMultiwordSize(current)+2 differs from words.size(); otherwise the number
/// of entries i for which "FORM" is absent at current+i or differs from words[i].
void Transition::initSplitWord(std::vector<std::string> words)
{
  auto consumedWord = util::splitAsUtf8(words[0]);
  sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::assertIsEmpty("FORM", Config::Object::Buffer, 0));
  sequence.emplace_back(Action::addLinesIfNeeded(words.size()));
  sequence.emplace_back(Action::consumeCharacterIndex(consumedWord));
  for (unsigned int i = 0; i < words.size(); i++)
    sequence.emplace_back(Action::addHypothesisRelative("FORM", Config::Object::Buffer, i, words[i]));
  sequence.emplace_back(Action::setMultiwordIds(words.size()-1));

  cost = [words](const Config & config)
  {
    if (!config.isMultiword(config.getWordIndex()))
      return std::numeric_limits<int>::max();

    if (config.getMultiwordSize(config.getWordIndex())+2 != static_cast<int>(words.size()))
      return std::numeric_limits<int>::max();

    int cost = 0;
    for (unsigned int i = 0; i < words.size(); i++)
      if (!config.has("FORM", config.getWordIndex()+i, 0) or config.getConst("FORM", config.getWordIndex()+i, 0) != words[i])
        cost++;

    return cost;
  };
}
/// Sets up a SPLIT transition: registers the split action for `index` and
/// installs a cost that delegates to the `index`-th entry of
/// config.getAppliableSplitTransitions(), or is max int when `index` is outside
/// that list.
void Transition::initSplit(int index)
{
  sequence.emplace_back(Action::split(index));

  cost = [index](const Config & config) -> int
  {
    auto & candidates = config.getAppliableSplitTransitions();

    const bool inRange = index >= 0 and index < (int)candidates.size();
    if (inRange)
      return candidates[index]->getCost(config);

    return std::numeric_limits<int>::max();
  };
}
sequence.emplace_back(Action::pushWordIndexOnStack());
sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
cost = [](const Config & config)
{
if (config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
return std::numeric_limits<int>::max();
return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);
/// Sets up the standard SHIFT transition.
///
/// Actions, in order: push the current word index on the stack, then
/// setRootUpdateIdsEmptyStackIfSentChanged.
///
/// Cost: max int when the stack is non-empty and its top carries
/// Config::EOSSymbol1 in the EOS column; 0 otherwise.
void Transition::initStandardShift()
{
  sequence.emplace_back(Action::pushWordIndexOnStack());
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

  cost = [](const Config & config) -> int
  {
    const bool topEndsSentence = config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1;
    return topEndsSentence ? std::numeric_limits<int>::max() : 0;
  };
}
void Transition::initEagerLeft_rel(std::string label)
sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
cost = [label](const Config & config)
{
auto stackIndex = config.getStack(0);
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
auto wordIndex = config.getWordIndex();
if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
return std::numeric_limits<int>::max();
int cost = getNbLinkedWith(wordIndex+1, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
if (stackGovIndex != std::to_string(wordIndex))
++cost;
if (label != config.getConst(Config::deprelColName, stackIndex, 0))
++cost;
return cost;
};
}
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
/// Sets up the labelled standard LEFT transition.
///
/// Actions, in order: attach (Stack 0, Stack 1), write `label` into the deprel
/// column of stack position 1, then pop stack position 1.
///
/// Cost: max int when either of the two topmost stack elements is not a token.
/// Otherwise getNbLinkedWith over the buffer range from the current word index
/// to the last index of the sentence, with respect to stack element 1; plus 1
/// when the head-column value of stack element 1 differs from stack element 0;
/// plus 1 when `label` differs from the deprel-column value of stack element 1.
///
/// @param label dependency label to write
void Transition::initStandardLeft_rel(std::string label)
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Stack, 1));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 1, label));
  sequence.emplace_back(Action::popStack(1));

  cost = [label](const Config & config) -> int
  {
    auto dependent = config.getStack(1);
    auto dependentHead = config.getConst(Config::headColName, dependent, 0);
    auto governor = config.getStack(0);

    if (!config.isToken(dependent) || !config.isToken(governor))
      return std::numeric_limits<int>::max();

    int penalty = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, dependent, config);

    // Wrong governor.
    if (dependentHead != std::to_string(governor))
      penalty += 1;

    // Wrong label.
    if (label != config.getConst(Config::deprelColName, dependent, 0))
      penalty += 1;

    return penalty;
  };
}
/// Sets up the unlabelled eager LEFT transition.
///
/// Action: attach (Buffer 0, Stack 0).
///
/// Cost: max int when the stack top or the current word is not a token.
/// Otherwise getNbLinkedWith over the buffer range from the word after the
/// current one to the last index of the sentence, with respect to the stack
/// top; plus 1 when the head-column value of the stack top differs from the
/// current word index.
void Transition::initEagerLeft()
{
  sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));

  cost = [](const Config & config) -> int
  {
    auto dependent = config.getStack(0);
    auto dependentHead = config.getConst(Config::headColName, dependent, 0);
    auto governor = config.getWordIndex();

    if (!config.isToken(dependent) || !config.isToken(governor))
      return std::numeric_limits<int>::max();

    int penalty = getNbLinkedWith(governor+1, getLastIndexOfSentence(governor, config), Config::Object::Buffer, dependent, config);

    // Wrong governor.
    if (dependentHead != std::to_string(governor))
      penalty += 1;

    return penalty;
  };
}
void Transition::initEagerRight_rel(std::string label)
sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
sequence.emplace_back(Action::pushWordIndexOnStack());
sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
cost = [label](const Config & config)
{
auto stackIndex = config.getStack(0);
auto wordIndex = config.getWordIndex();
if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
return std::numeric_limits<int>::max();
auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);
if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
return std::numeric_limits<int>::max();
int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);
if (bufferGovIndex != std::to_string(stackIndex))
++cost;
if (label != config.getConst(Config::deprelColName, wordIndex, 0))
++cost;
return cost;
};
}
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
/// Sets up the labelled standard RIGHT transition.
///
/// Actions, in order: attach (Stack 1, Stack 0), write `label` into the deprel
/// column of stack position 0, then pop stack position 0.
///
/// Cost: max int when either of the two topmost stack elements is not a token.
/// Otherwise getNbLinkedWith over the buffer range from the current word index
/// to the last index of the sentence, with respect to stack element 0; plus 1
/// when the head-column value of stack element 0 differs from stack element 1;
/// plus 1 when `label` differs from the deprel-column value of stack element 0.
///
/// @param label dependency label to write
void Transition::initStandardRight_rel(std::string label)
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 1, Config::Object::Stack, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
  sequence.emplace_back(Action::popStack(0));

  cost = [label](const Config & config) -> int
  {
    auto governor = config.getStack(1);
    auto dependent = config.getStack(0);

    if (!config.isToken(governor) || !config.isToken(dependent))
      return std::numeric_limits<int>::max();

    auto dependentHead = config.getConst(Config::headColName, dependent, 0);

    int penalty = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, dependent, config);

    // Wrong governor.
    if (dependentHead != std::to_string(governor))
      penalty += 1;

    // Wrong label.
    if (label != config.getConst(Config::deprelColName, dependent, 0))
      penalty += 1;

    return penalty;
  };
}
/// Sets up the unlabelled eager RIGHT transition.
///
/// Actions, in order: attach (Stack 0, Buffer 0), push the current word index
/// on the stack, then setRootUpdateIdsEmptyStackIfSentChanged.
///
/// Cost: max int when the stack top or the current word is not a token, or when
/// the stack top carries Config::EOSSymbol1 in the EOS column. Otherwise
/// getNbLinkedWith over stack depths 1..getStackSize()-1 with respect to the
/// current word index; plus 1 when the head-column value of the current word
/// differs from the stack top.
void Transition::initEagerRight()
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::pushWordIndexOnStack());
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

  cost = [](const Config & config) -> int
  {
    auto governor = config.getStack(0);
    auto dependent = config.getWordIndex();

    if (!config.isToken(governor) || !config.isToken(dependent))
      return std::numeric_limits<int>::max();

    auto dependentHead = config.getConst(Config::headColName, dependent, 0);

    if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
      return std::numeric_limits<int>::max();

    int penalty = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);

    // Wrong governor.
    if (dependentHead != std::to_string(governor))
      penalty += 1;

    return penalty;
  };
}
// Sets up the strict REDUCE transition: registers an assertIsNotEmpty action on
// the head column of stack position 0, and installs a cost made of the
// getNbLinkedWith count with respect to the stack top, plus 1 when the stack
// top carries Config::EOSSymbol1 in the EOS column.
// NOTE(review): the opening brace of the function body is absent from the
// visible text, and wordIndex is used in the cost without a visible
// declaration — lines appear to be missing; confirm against the full file.
void Transition::initReduce_strict()
sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
cost = [](const Config & config)
{
auto stackIndex = config.getStack(0);
// Counted over the buffer range from wordIndex to the last index of its sentence.
int cost = getNbLinkedWith(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
if (config.getConst(Config::EOSColName, stackIndex, 0) == Config::EOSSymbol1)
++cost;
return cost;
};
}
// Sets up the relaxed REDUCE transition. The visible cost starts like the
// strict variant's: a getNbLinkedWith count with respect to the stack top,
// followed by a test of the stack top's EOS column against Config::EOSSymbol1.
// NOTE(review): wordIndex is used without a visible declaration, and the text
// of this definition is cut short — see the note further down; confirm against
// the full file.
void Transition::initReduce_relaxed()
{
cost = [](const Config & config)
{
auto stackIndex = config.getStack(0);
int cost = getNbLinkedWith(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
if (config.getConst(Config::EOSColName, stackIndex, 0) == Config::EOSSymbol1)
// NOTE(review): from here on the lines use bufferIndex, which nothing visible
// declares, and neither the lambda nor the function above is closed. These
// lines presumably belong to a different definition (the constructor calls
// initEOS(int), which has no visible definition) — confirm against the full file.
sequence.emplace_back(Action::setRoot(bufferIndex));
sequence.emplace_back(Action::updateIds(bufferIndex));
sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1));
// Max-int cost when the EOS column at the resolved line is not Config::EOSSymbol1.
int lineIndex = config.getRelativeWordIndex(Config::Buffer, bufferIndex);
if (config.getConst(Config::EOSColName, lineIndex, 0) != Config::EOSSymbol1)
return std::numeric_limits<int>::max();
/// Sets up a deprel transition: registers the deprel action for `label` and
/// installs a cost of 0 when the deprel-column value at
/// config.getLastAttached() equals `label`, 1 otherwise.
///
/// @param label dependency label to assign
void Transition::initDeprel(std::string label)
{
  sequence.emplace_back(Action::deprel(label));

  cost = [label](const Config & config) -> int
  {
    if (config.getConst(Config::deprelColName, config.getLastAttached(), 0) == label)
      return 0;

    return 1;
  };
}
/// Counts head-column links between `withIndex` and the positions
/// firstIndex..lastIndex (both inclusive) for which the getAsFeature head value
/// of the dependent side is still empty.
///
/// For each position, one is added when the getConst head value of `withIndex`
/// designates that position and the getAsFeature head value of `withIndex` is
/// empty; one more is added when the getConst head value of that position
/// designates `withIndex` and the getAsFeature head value of that position is
/// empty.
///
/// @param firstIndex first position examined
/// @param lastIndex  last position examined (inclusive)
/// @param object     when Config::Object::Stack, positions are stack depths
///                   resolved through config.getStack; otherwise they are used
///                   directly as indexes
/// @param withIndex  index the links are counted against
/// @param config     configuration to read from
/// @return the number of links counted; positions that are not tokens are skipped
int Transition::getNbLinkedWith(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
{
  auto govIndex = config.getConst(Config::headColName, withIndex, 0);
  auto govIndexPredicted = config.getAsFeature(Config::headColName, withIndex);
  // Loop-invariant: textual form of withIndex, compared against head-column values.
  const std::string withIndexStr = std::to_string(withIndex);

  int nbLinkedWith = 0;

  for (int i = firstIndex; i <= lastIndex; ++i)
  {
    int index = i;
    if (object == Config::Object::Stack)
      index = config.getStack(i);

    if (!config.isToken(index))
      continue;

    auto otherGovIndex = config.getConst(Config::headColName, index, 0);
    auto otherGovIndexPredicted = config.getAsFeature(Config::headColName, index);

    if (govIndex == std::to_string(index) and util::isEmpty(govIndexPredicted))
      ++nbLinkedWith;
    if (otherGovIndex == withIndexStr and util::isEmpty(otherGovIndexPredicted))
      ++nbLinkedWith;
  }

  return nbLinkedWith;
}
/// Walks downwards from `baseIndex` while config.has(0, i, 0) holds and returns
/// the lowest token index reached before a token carrying Config::EOSSymbol1 in
/// the EOS column is met. Non-token positions are stepped over. Returns
/// `baseIndex` when no such token index is recorded.
int Transition::getFirstIndexOfSentence(int baseIndex, const Config & config)
{
  int result = baseIndex;
  int cursor = baseIndex;

  while (config.has(0, cursor, 0))
  {
    if (config.isToken(cursor))
    {
      // The end-of-sentence token is not part of the range being searched.
      if (config.getConst(Config::EOSColName, cursor, 0) == Config::EOSSymbol1)
        break;

      result = cursor;
    }

    --cursor;
  }

  return result;
}
/// Walks upwards from `baseIndex` while config.has(0, i, 0) holds and returns
/// the highest token index reached, stopping at (and including) the first token
/// carrying Config::EOSSymbol1 in the EOS column. Non-token positions are
/// stepped over. Returns `baseIndex` when no token is met.
int Transition::getLastIndexOfSentence(int baseIndex, const Config & config)
{
  int result = baseIndex;
  int cursor = baseIndex;

  while (config.has(0, cursor, 0))
  {
    if (config.isToken(cursor))
    {
      // Record first, then test: the end-of-sentence token itself is included.
      result = cursor;

      if (config.getConst(Config::EOSColName, cursor, 0) == Config::EOSSymbol1)
        break;
    }

    ++cursor;
  }

  return result;
}