Newer
Older
Transition::Transition(const std::string & name)
{
std::vector<std::pair<std::regex, std::function<void(const std::smatch &)>>> inits
{
{std::regex("WRITE ([bs])\\.(.+) (.+) (.+)"),
[this](auto sm){(initWrite(sm[3], sm[1], sm[2], sm[4]));}},
{std::regex("ADD ([bs])\\.(.+) (.+) (.+)"),
[this](auto sm){(initAdd(sm[3], sm[1], sm[2], sm[4]));}},
{std::regex("SHIFT"),
[this](auto){initShift();}},
{std::regex("REDUCE_strict"),
[this](auto){initReduce_strict();}},
{std::regex("REDUCE_relaxed"),
[this](auto){initReduce_relaxed();}},
{std::regex("eager_LEFT_rel (.+)"),
[this](auto sm){(initEagerLeft_rel(sm[1]));}},
{std::regex("eager_RIGHT_rel (.+)"),
[this](auto sm){(initEagerRight_rel(sm[1]));}},
{std::regex("eager_LEFT"),
[this](auto){(initEagerLeft());}},
{std::regex("eager_RIGHT"),
[this](auto){(initEagerRight());}},
{std::regex("deprel (.+)"),
[this](auto sm){(initDeprel(sm[1]));}},
{std::regex("EOS b\\.(.+)"),
[this](auto sm){initEOS(std::stoi(sm[1]));}},
{std::regex("NOTHING"),
[this](auto){initNothing();}},
{std::regex("IGNORECHAR"),
[this](auto){initIgnoreChar();}},
{std::regex("ENDWORD"),
[this](auto){initEndWord();}},
{std::regex("ADDCHARTOWORD"),
[this](auto){initAddCharToWord();}},
{std::regex("SPLIT (.+)"),
[this](auto sm){(initSplit(std::stoi(sm.str(1))));}},
{std::regex("SPLITWORD ([^@]+)(:?(:?@[^@]+)+)"),
[this](auto sm)
{
std::vector<std::string> splitRes{sm[1]};
auto splited = util::split(std::string(sm[2]), '@');
for (auto & s : splited)
splitRes.emplace_back(s);
initSplitWord(splitRes);
}},
};
if (!util::doIfNameMatch(std::regex("(<(.+)> )?(.+)"), name, [this, name](auto sm)
{
this->state = sm[2];
this->name = sm[3];
}))
util::myThrow("doesn't match nameRegex");
for (auto & it : inits)
if (util::doIfNameMatch(it.first, this->name, it.second))
return;
Franck Dary
committed
throw std::invalid_argument("no match");
} catch (std::exception & e) {util::myThrow(fmt::format("Invalid name '{}' ({})", this->name, e.what()));}
Franck Dary
committed
}
void Transition::apply(Config & config)
{
for (Action & action : sequence)
action.apply(config, action);
}
bool Transition::appliable(const Config & config) const
{
if (!state.empty() && state != config.getState())
return false;
for (const Action & action : sequence)
if (!action.appliable(config, action))
return false;
return true;
}
int Transition::getCost(const Config & config) const
{
return cost(config);
}
const std::string & Transition::getName() const
{
return name;
}
Franck Dary
committed
void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
{
Franck Dary
committed
int indexValue = std::stoi(index);
sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));
cost = [colName, objectValue, indexValue, value](const Config & config)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
Franck Dary
committed
if (config.getConst(colName, lineIndex, 0) == value)
return 0;
Franck Dary
committed
return 1;
};
void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
{
int indexValue = std::stoi(index);
sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));
cost = [colName, objectValue, indexValue, value](const Config & config)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto gold = util::split(config.getConst(colName, lineIndex, 0).get(), '|');
for (auto & part : gold)
if (part == value)
return 0;
return 1;
};
}
void Transition::initNothing()
{
cost = [](const Config &)
{
return 0;
};
}
void Transition::initIgnoreChar()
{
sequence.emplace_back(Action::ignoreCurrentCharacter());
cost = [](const Config &)
{
return 0;
};
}
void Transition::initEndWord()
{
sequence.emplace_back(Action::endWord());
cost = [](const Config & config)
{
if (config.getConst("FORM", config.getWordIndex(), 0) == config.getAsFeature("FORM", config.getWordIndex()))
return 0;
return 1;
};
}
void Transition::initAddCharToWord()
{
sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0));
sequence.emplace_back(Action::addLinesIfNeeded(0));
sequence.emplace_back(Action::addCurCharToCurWord());
sequence.emplace_back(Action::moveCharacterIndex(1));
cost = [](const Config & config)
{
if (!config.hasCharacter(config.getCharacterIndex()))
return std::numeric_limits<int>::max();
auto letter = fmt::format("{}", config.getLetter(config.getCharacterIndex()));
auto & goldWord = config.getConst("FORM", config.getWordIndex(), 0).get();
auto & curWord = config.getAsFeature("FORM", config.getWordIndex()).get();
if (curWord.size() + letter.size() > goldWord.size())
return 1;
for (unsigned int i = 0; i < letter.size(); i++)
if (goldWord[curWord.size()+i] != letter[i])
return 1;
return 0;
};
}
void Transition::initSplitWord(std::vector<std::string> words)
{
auto consumedWord = util::splitAsUtf8(words[0]);
sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0));
sequence.emplace_back(Action::assertIsEmpty("FORM", Config::Object::Buffer, 0));
sequence.emplace_back(Action::addLinesIfNeeded(words.size()));
sequence.emplace_back(Action::consumeCharacterIndex(consumedWord));
for (unsigned int i = 0; i < words.size(); i++)
sequence.emplace_back(Action::addHypothesisRelative("FORM", Config::Object::Buffer, i, words[i]));
Franck Dary
committed
sequence.emplace_back(Action::setMultiwordIds(words.size()-1));
cost = [words](const Config & config)
{
Franck Dary
committed
if (!config.isMultiword(config.getWordIndex()))
return std::numeric_limits<int>::max();
if (config.getMultiwordSize(config.getWordIndex())+2 != (int)words.size())
return std::numeric_limits<int>::max();
int cost = 0;
for (unsigned int i = 0; i < words.size(); i++)
if (!config.has("FORM", config.getWordIndex()+i, 0) or config.getConst("FORM", config.getWordIndex()+i, 0) != words[i])
cost++;
return cost;
};
}
void Transition::initSplit(int index)
{
sequence.emplace_back(Action::split(index));
cost = [index](const Config & config)
{
auto & transitions = config.getAppliableSplitTransitions();
if (index < 0 or index >= (int)transitions.size())
return std::numeric_limits<int>::max();
return transitions[index]->getCost(config);
};
}
sequence.emplace_back(Action::pushWordIndexOnStack());
sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
cost = [](const Config & config)
{
if (config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
return std::numeric_limits<int>::max();
auto headGovIndex = config.getConst(Config::headColName, config.getWordIndex(), 0);
int cost = 0;
for (int i = 0; config.hasStack(i); ++i)
{
Franck Dary
committed
if (!config.has(0, config.getStack(i), 0))
continue;
auto stackIndex = config.getStack(i);
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
if (stackGovIndex == std::to_string(config.getWordIndex()) || headGovIndex == std::to_string(stackIndex))
++cost;
}
return cost;
};
}
void Transition::initEagerLeft_rel(std::string label)
sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
sequence.emplace_back(Action::popStack());
cost = [label](const Config & config)
{
auto stackIndex = config.getStack(0);
auto wordIndex = config.getWordIndex();
if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
return std::numeric_limits<int>::max();
int cost = 0;
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
for (int i = wordIndex+1; config.has(0, i, 0); ++i)
{
if (!config.isToken(i))
continue;
if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
break;
auto otherGovIndex = config.getConst(Config::headColName, i, 0);
if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
++cost;
}
//TODO : Check if this is necessary
if (stackGovIndex != std::to_string(wordIndex))
++cost;
if (label != config.getConst(Config::deprelColName, stackIndex, 0))
++cost;
return cost;
};
}
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
void Transition::initEagerLeft()
{
sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
sequence.emplace_back(Action::popStack());
cost = [](const Config & config)
{
auto stackIndex = config.getStack(0);
auto wordIndex = config.getWordIndex();
if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
return std::numeric_limits<int>::max();
int cost = 0;
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
for (int i = wordIndex+1; config.has(0, i, 0); ++i)
{
if (!config.isToken(i))
continue;
if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
break;
auto otherGovIndex = config.getConst(Config::headColName, i, 0);
if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
++cost;
}
//TODO : Check if this is necessary
if (stackGovIndex != std::to_string(wordIndex))
++cost;
return cost;
};
}
void Transition::initEagerRight_rel(std::string label)
sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
sequence.emplace_back(Action::pushWordIndexOnStack());
sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
cost = [label](const Config & config)
{
auto stackIndex = config.getStack(0);
auto wordIndex = config.getWordIndex();
if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
return std::numeric_limits<int>::max();
int cost = 0;
auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);
for (int i = wordIndex; config.has(0, i, 0); ++i)
{
if (!config.isToken(i))
continue;
auto otherGovIndex = config.getConst(Config::headColName, i, 0);
++cost;
if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
break;
}
for (int i = 1; config.hasStack(i); ++i)
{
Franck Dary
committed
if (!config.has(0, config.getStack(i), 0))
continue;
auto otherStackIndex = config.getStack(i);
auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0);
if (otherStackGov == std::to_string(wordIndex) || bufferGovIndex == std::to_string(otherStackIndex))
++cost;
}
//TODO : Check if this is necessary
if (bufferGovIndex != std::to_string(stackIndex))
++cost;
if (label != config.getConst(Config::deprelColName, wordIndex, 0))
++cost;
return cost;
};
}
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
void Transition::initEagerRight()
{
sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
sequence.emplace_back(Action::pushWordIndexOnStack());
sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
cost = [](const Config & config)
{
auto stackIndex = config.getStack(0);
auto wordIndex = config.getWordIndex();
if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
return std::numeric_limits<int>::max();
int cost = 0;
auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);
for (int i = wordIndex; config.has(0, i, 0); ++i)
{
if (!config.isToken(i))
continue;
auto otherGovIndex = config.getConst(Config::headColName, i, 0);
if (bufferGovIndex == std::to_string(i))
++cost;
if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
break;
}
for (int i = 1; config.hasStack(i); ++i)
{
if (!config.has(0, config.getStack(i), 0))
continue;
auto otherStackIndex = config.getStack(i);
auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0);
if (otherStackGov == std::to_string(wordIndex) || bufferGovIndex == std::to_string(otherStackIndex))
++cost;
}
//TODO : Check if this is necessary
if (bufferGovIndex != std::to_string(stackIndex))
++cost;
return cost;
};
}
void Transition::initReduce_strict()
sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
sequence.emplace_back(Action::popStack());
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
cost = [](const Config & config)
{
if (!config.isToken(config.getStack(0)))
return 0;
int cost = 0;
auto stackIndex = config.getStack(0);
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
for (int i = config.getWordIndex(); config.has(0, i, 0); ++i)
{
if (!config.isToken(i))
continue;
auto otherGovIndex = config.getConst(Config::headColName, i, 0);
if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
++cost;
if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
break;
}
if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
++cost;
return cost;
};
}
void Transition::initReduce_relaxed()
{
sequence.emplace_back(Action::popStack());
cost = [](const Config & config)
{
Franck Dary
committed
if (!config.isToken(config.getStack(0)))
auto stackIndex = config.getStack(0);
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
for (int i = config.getWordIndex(); config.has(0, i, 0); ++i)
{
if (!config.isToken(i))
continue;
auto otherGovIndex = config.getConst(Config::headColName, i, 0);
if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
++cost;
if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
break;
}
if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
++cost;
return cost;
};
sequence.emplace_back(Action::setRoot(bufferIndex));
sequence.emplace_back(Action::updateIds(bufferIndex));
sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1));
int lineIndex = config.getRelativeWordIndex(Config::Buffer, bufferIndex);
if (config.getConst(Config::EOSColName, lineIndex, 0) != Config::EOSSymbol1)
return std::numeric_limits<int>::max();
void Transition::initDeprel(std::string label)
{
sequence.emplace_back(Action::deprel(label));
cost = [label](const Config & config)
{
return config.getConst(Config::deprelColName, config.getLastAttached(), 0) == label ? 0 : 1;
};
}