// Transition.cpp
#include "Transition.hpp"
#include <limits>
#include <regex>
#include <stdexcept>
/// Builds a transition from its textual description.
///
/// `name` may start with an optional "<state> " prefix, which is stored in
/// `state`; the remainder is stored in `name` and matched against the patterns
/// below. The initializer of the first matching pattern fills `sequence` and
/// installs `cost`.
///
/// Any failure is reported through util::myThrow with the full original name:
/// - the name does not match the name regex;
/// - no transition pattern matches;
/// - an initializer throws (e.g. std::stoi on a non-numeric index).
Transition::Transition(const std::string & name)
{
  // Pattern -> initializer; capture groups are forwarded to the init* method.
  // Matches are taken by const reference to avoid copying the std::smatch.
  std::vector<std::pair<std::regex, std::function<void(const std::smatch &)>>> inits
  {
    {std::regex("WRITE ([bs])\\.(.+) (.+) (.+)"),
      [this](const auto & sm){initWrite(sm[3], sm[1], sm[2], sm[4]);}},
    {std::regex("ADD ([bs])\\.(.+) (.+) (.+)"),
      [this](const auto & sm){initAdd(sm[3], sm[1], sm[2], sm[4]);}},
    {std::regex("eager_SHIFT"),
      [this](const auto &){initEagerShift();}},
    {std::regex("standard_SHIFT"),
      [this](const auto &){initStandardShift();}},
    {std::regex("REDUCE_strict"),
      [this](const auto &){initReduce_strict();}},
    {std::regex("REDUCE_relaxed"),
      [this](const auto &){initReduce_relaxed();}},
    {std::regex("eager_LEFT_rel (.+)"),
      [this](const auto & sm){initEagerLeft_rel(sm[1]);}},
    {std::regex("eager_RIGHT_rel (.+)"),
      [this](const auto & sm){initEagerRight_rel(sm[1]);}},
    {std::regex("standard_LEFT_rel (.+)"),
      [this](const auto & sm){initStandardLeft_rel(sm[1]);}},
    {std::regex("standard_RIGHT_rel (.+)"),
      [this](const auto & sm){initStandardRight_rel(sm[1]);}},
    {std::regex("eager_LEFT"),
      [this](const auto &){initEagerLeft();}},
    {std::regex("eager_RIGHT"),
      [this](const auto &){initEagerRight();}},
    {std::regex("deprel (.+)"),
      [this](const auto & sm){initDeprel(sm[1]);}},
    {std::regex("EOS b\\.(.+)"),
      [this](const auto & sm){initEOS(std::stoi(sm.str(1)));}},
    {std::regex("NOTHING"),
      [this](const auto &){initNothing();}},
    {std::regex("IGNORECHAR"),
      [this](const auto &){initIgnoreChar();}},
    {std::regex("ENDWORD"),
      [this](const auto &){initEndWord();}},
    {std::regex("ADDCHARTOWORD"),
      [this](const auto &){initAddCharToWord();}},
    {std::regex("SPLIT (.+)"),
      [this](const auto & sm){initSplit(std::stoi(sm.str(1)));}},
    // "SPLITWORD <word>@<part>@<part>...": group 1 is the consumed word,
    // group 2 the whole '@'-prefixed list of parts; (?:...) is non-capturing.
    {std::regex("SPLITWORD ([^@]+)((?:@[^@]+)+)"),
      [this](const auto & sm)
      {
        std::vector<std::string> splitRes{sm[1]};
        auto splited = util::split(std::string(sm[2]), '@');
        for (auto & s : splited)
          splitRes.emplace_back(s);
        initSplitWord(splitRes);
      }},
  };

  try
  {
    // Separate the optional "<state> " prefix from the transition name proper.
    if (!util::doIfNameMatch(std::regex("(<(.+)> )?(.+)"), name, [this](const auto & sm)
          {
            this->state = sm[2];
            this->name = sm[3];
          }))
      util::myThrow("doesn't match nameRegex");

    for (auto & it : inits)
      if (util::doIfNameMatch(it.first, this->name, it.second))
        return;

    // Without this, an unknown name would silently produce a transition with an
    // empty action sequence and no cost function installed.
    throw std::invalid_argument("no match");
  } catch (std::exception & e) {util::myThrow(fmt::format("Invalid name '{}' ({})", name, e.what()));}
}
/// Applies every action of this transition to `config`, in sequence order.
/// Each action is handed itself as second argument, as Action::apply expects.
void Transition::apply(Config & config)
{
  for (auto & step : sequence)
    step.apply(config, step);
}

/// Tells whether this transition can be applied to `config`.
///
/// A transition bound to a state (non-empty `state`) is never appliable when
/// the configuration is in another state. Otherwise it is appliable when every
/// action of `sequence` reports itself appliable. Exceptions raised while
/// checking are re-raised through util::myThrow with the transition name.
bool Transition::appliable(const Config & config) const
{
  try
  {
    // State mismatch: reject immediately instead of letting this guard
    // control the action loop below.
    if (!state.empty() && state != config.getState())
      return false;

    for (const Action & action : sequence)
      if (!action.appliable(config, action))
        return false;
  } catch (std::exception & e)
  {
    util::myThrow(fmt::format("transition '{}' {}", name, e.what()));
  }

  return true;
}

/// Cost of applying this transition to `config`, as computed by the cost
/// function installed by the matching init* method.
/// Exceptions raised by that function are re-raised through util::myThrow with
/// the transition name prepended.
int Transition::getCost(const Config & config) const
{
  try {return cost(config);}
  catch (std::exception & e) { util::myThrow(fmt::format("transition '{}' {}", name, e.what()));}

  // Only reached if util::myThrow returns (presumably it always throws — TODO confirm).
  return 0;
}
const std::string & Transition::getName() const
{
  return name;
}

void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
{
Franck Dary's avatar
Franck Dary committed
  auto objectValue = Config::str2object(object);
  int indexValue = std::stoi(index);

  sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));

  cost = [colName, objectValue, indexValue, value](const Config & config)
  {
Franck Dary's avatar
Franck Dary committed
    int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
    if (config.getConst(colName, lineIndex, 0) == value)
      return 0;
void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
{
Franck Dary's avatar
Franck Dary committed
  auto objectValue = Config::str2object(object);
  int indexValue = std::stoi(index);

  sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));

  cost = [colName, objectValue, indexValue, value](const Config & config)
  {
Franck Dary's avatar
Franck Dary committed
    int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);

    auto gold = util::split(config.getConst(colName, lineIndex, 0).get(), '|');

    for (auto & part : gold)
      if (part == value)
        return 0;

    return 1;
  };
}

/// NOTHING transition: performs no action and always costs 0.
void Transition::initNothing()
{
  cost = [](const Config &) { return 0; };
}

/// IGNORECHAR transition: a single Action::ignoreCurrentCharacter(); always costs 0.
void Transition::initIgnoreChar()
{
  auto zeroCost = [](const Config &) { return 0; };

  sequence.emplace_back(Action::ignoreCurrentCharacter());
  cost = zeroCost;
}

/// ENDWORD transition: a single Action::endWord().
///
/// Cost: 0 when the FORM feature of the current word (config.getAsFeature)
/// equals its gold FORM (config.getConst), 1 otherwise.
void Transition::initEndWord()
{
  sequence.emplace_back(Action::endWord());

  cost = [](const Config & config)
  {
    const auto current = config.getWordIndex();
    return config.getConst("FORM", current, 0) == config.getAsFeature("FORM", current) ? 0 : 1;
  };
}

void Transition::initAddCharToWord()
{
Franck Dary's avatar
Franck Dary committed
  sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::addLinesIfNeeded(0));
  sequence.emplace_back(Action::addCurCharToCurWord());
  sequence.emplace_back(Action::moveCharacterIndex(1));

  cost = [](const Config & config)
  {
    if (!config.hasCharacter(config.getCharacterIndex()))
      return std::numeric_limits<int>::max();

    auto letter = fmt::format("{}", config.getLetter(config.getCharacterIndex()));
    auto & goldWord = config.getConst("FORM", config.getWordIndex(), 0).get();
    auto & curWord = config.getAsFeature("FORM", config.getWordIndex()).get();
    if (curWord.size() + letter.size() > goldWord.size())
      return 1;

    for (unsigned int i = 0; i < letter.size(); i++)
      if (goldWord[curWord.size()+i] != letter[i])
        return 1;

    return 0;
  };
}

/// SPLITWORD transition: `words[0]` is the word consumed from the input and
/// `words[1..]` the words it is split into (built from the SPLITWORD pattern
/// in the constructor).
///
/// Action sequence, in order:
/// 1. assert buffer[0] has an empty ID and FORM;
/// 2. add `words.size()` lines if needed;
/// 3. consume the characters of words[0];
/// 4. write each words[i] as the FORM hypothesis of buffer[i];
/// 5. call Action::setMultiwordIds(words.size()-1).
///
/// Cost:
/// - max int when the current word is not a multiword;
/// - max int when getMultiwordSize(current)+2 differs from words.size();
/// - otherwise, the number of offsets i whose gold FORM at current+i is
///   missing or differs from words[i].
void Transition::initSplitWord(std::vector<std::string> words)
{
  auto consumedWord = util::splitAsUtf8(words[0]);
  sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::assertIsEmpty("FORM", Config::Object::Buffer, 0));
  sequence.emplace_back(Action::addLinesIfNeeded(words.size()));
  sequence.emplace_back(Action::consumeCharacterIndex(consumedWord));
  for (unsigned int i = 0; i < words.size(); i++)
    sequence.emplace_back(Action::addHypothesisRelative("FORM", Config::Object::Buffer, i, words[i]));
  sequence.emplace_back(Action::setMultiwordIds(words.size()-1));

  cost = [words](const Config & config)
  {
    if (!config.isMultiword(config.getWordIndex()))
      return std::numeric_limits<int>::max();

    if (config.getMultiwordSize(config.getWordIndex())+2 != static_cast<int>(words.size()))
      return std::numeric_limits<int>::max();

    int cost = 0;
    for (unsigned int i = 0; i < words.size(); i++)
      if (!config.has("FORM", config.getWordIndex()+i, 0) or config.getConst("FORM", config.getWordIndex()+i, 0) != words[i])
        cost++;

    return cost;
  };
}

Franck Dary's avatar
Franck Dary committed
void Transition::initSplit(int index)
{
  sequence.emplace_back(Action::split(index));

  cost = [index](const Config & config)
  {
    auto & transitions = config.getAppliableSplitTransitions();

    if (index < 0 or index >= (int)transitions.size())
      return std::numeric_limits<int>::max();

    return transitions[index]->getCost(config);
  };
}

void Transition::initEagerShift()
  sequence.emplace_back(Action::pushWordIndexOnStack());
Franck Dary's avatar
Franck Dary committed
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

  cost = [](const Config & config)
  {
    if (!config.isToken(config.getWordIndex()))
Franck Dary's avatar
Franck Dary committed
    return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);
/// standard_SHIFT transition: pushes the current word index on the stack,
/// followed by Action::setRootUpdateIdsEmptyStackIfSentChanged(). It never
/// costs anything.
void Transition::initStandardShift()
{
  cost = [](const Config &) { return 0; };

  sequence.emplace_back(Action::pushWordIndexOnStack());
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
}

void Transition::initEagerLeft_rel(std::string label)
Franck Dary's avatar
Franck Dary committed
  sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
Franck Dary's avatar
Franck Dary committed
  sequence.emplace_back(Action::popStack(0));

  cost = [label](const Config & config)
  {
Franck Dary's avatar
Franck Dary committed
    auto depIndex = config.getStack(0);
    auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
    auto govIndex = config.getWordIndex();
Franck Dary's avatar
Franck Dary committed
    if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0))
      return 0;
Franck Dary's avatar
Franck Dary committed
    int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config);
Franck Dary's avatar
Franck Dary committed
    if (label != config.getConst(Config::deprelColName, depIndex, 0))
void Transition::initStandardLeft_rel(std::string label)
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Stack, 1));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 1, label));
  sequence.emplace_back(Action::popStack(1));

  cost = [label](const Config & config)
  {
Franck Dary's avatar
Franck Dary committed
    auto depIndex = config.getStack(1);
    auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
    auto govIndex = config.getStack(0);
Franck Dary's avatar
Franck Dary committed
    if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0))
      return 0;
Franck Dary's avatar
Franck Dary committed
    int cost = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, depIndex, config);
    cost += getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config);
Franck Dary's avatar
Franck Dary committed
    if (label != config.getConst(Config::deprelColName, depIndex, 0))
void Transition::initEagerLeft()
{
  sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
Franck Dary's avatar
Franck Dary committed
  sequence.emplace_back(Action::popStack(0));
Franck Dary's avatar
Franck Dary committed
    auto depIndex = config.getStack(0);
    auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
    auto govIndex = config.getWordIndex();
Franck Dary's avatar
Franck Dary committed
    if (depGovIndex == std::to_string(govIndex))
      return 0;
Franck Dary's avatar
Franck Dary committed
    int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config);

    return cost;
  };
}

void Transition::initEagerRight_rel(std::string label)
Franck Dary's avatar
Franck Dary committed
  sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
  sequence.emplace_back(Action::pushWordIndexOnStack());
Franck Dary's avatar
Franck Dary committed
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

  cost = [label](const Config & config)
  {
Franck Dary's avatar
Franck Dary committed
    auto govIndex = config.getStack(0);
    auto depIndex = config.getWordIndex();
    auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
Franck Dary's avatar
Franck Dary committed
    if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0))
      return 0;
Franck Dary's avatar
Franck Dary committed
    int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config);
    cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config);
Franck Dary's avatar
Franck Dary committed
    if (label != config.getConst(Config::deprelColName, depIndex, 0))
void Transition::initStandardRight_rel(std::string label)
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 1, Config::Object::Stack, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
  sequence.emplace_back(Action::popStack(0));

  cost = [label](const Config & config)
  {
Franck Dary's avatar
Franck Dary committed
    auto govIndex = config.getStack(1);
    auto depIndex = config.getStack(0);
    auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
Franck Dary's avatar
Franck Dary committed
    if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0))
      return 0;
Franck Dary's avatar
Franck Dary committed
    int cost = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, depIndex, config);
    cost += getNbLinkedWith(2, config.getStackSize()-1, Config::Object::Stack, depIndex, config);
Franck Dary's avatar
Franck Dary committed
    if (label != config.getConst(Config::deprelColName, depIndex, 0))
void Transition::initEagerRight()
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::pushWordIndexOnStack());
  sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

  cost = [](const Config & config)
  {
Franck Dary's avatar
Franck Dary committed
    auto govIndex = config.getStack(0);
    auto depIndex = config.getWordIndex();
    auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
Franck Dary's avatar
Franck Dary committed
    if (depGovIndex == std::to_string(govIndex))
      return 0;
Franck Dary's avatar
Franck Dary committed
    int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config);
    cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config);
void Transition::initReduce_strict()
Franck Dary's avatar
Franck Dary committed
  sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
Franck Dary's avatar
Franck Dary committed
  sequence.emplace_back(Action::popStack(0));
  cost = [](const Config & config)
  {
    auto stackIndex = config.getStack(0);
Franck Dary's avatar
Franck Dary committed
    auto wordIndex = config.getWordIndex();
Franck Dary's avatar
Franck Dary committed
    if (!config.isToken(stackIndex))
      return 0;
Franck Dary's avatar
Franck Dary committed
    int cost = getNbLinkedWithDeps(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);

    return cost;
  };
}

void Transition::initReduce_relaxed()
{
Franck Dary's avatar
Franck Dary committed
  sequence.emplace_back(Action::popStack(0));
  cost = [](const Config & config)
  {
    auto stackIndex = config.getStack(0);
Franck Dary's avatar
Franck Dary committed
    auto wordIndex = config.getWordIndex();
Franck Dary's avatar
Franck Dary committed
    if (!config.isToken(stackIndex))
      return 0;
Franck Dary's avatar
Franck Dary committed
    int cost = getNbLinkedWith(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
Franck Dary's avatar
Franck Dary committed
void Transition::initEOS(int bufferIndex)
Franck Dary's avatar
Franck Dary committed
{
Franck Dary's avatar
Franck Dary committed
  sequence.emplace_back(Action::setRoot(bufferIndex));
  sequence.emplace_back(Action::updateIds(bufferIndex));
  sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1));
Franck Dary's avatar
Franck Dary committed
  sequence.emplace_back(Action::emptyStack());
Franck Dary's avatar
Franck Dary committed

Franck Dary's avatar
Franck Dary committed
  cost = [bufferIndex](const Config & config)
Franck Dary's avatar
Franck Dary committed
  {
Franck Dary's avatar
Franck Dary committed
    int lineIndex = config.getRelativeWordIndex(Config::Buffer, bufferIndex);
    if (config.getConst(Config::EOSColName, lineIndex, 0) != Config::EOSSymbol1)
      return std::numeric_limits<int>::max();
Franck Dary's avatar
Franck Dary committed

Franck Dary's avatar
Franck Dary committed
    return 0;
/// deprel transition: a single Action::deprel(label).
///
/// Cost: 0 when the gold deprel of config.getLastAttached() equals `label`,
/// 1 otherwise.
void Transition::initDeprel(std::string label)
{
  sequence.emplace_back(Action::deprel(label));

  cost = [label](const Config & config)
  {
    if (config.getConst(Config::deprelColName, config.getLastAttached(), 0) == label)
      return 0;

    return 1;
  };
}

Franck Dary's avatar
Franck Dary committed
int Transition::getNbLinkedWith(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
{
  auto govIndex = config.getConst(Config::headColName, withIndex, 0);
Franck Dary's avatar
Franck Dary committed
  auto govIndexPredicted = config.getAsFeature(Config::headColName, withIndex);
Franck Dary's avatar
Franck Dary committed

  int nbLinkedWith = 0;

  for (int i = firstIndex; i <= lastIndex; ++i)
  {
    int index = i;
    if (object == Config::Object::Stack)
      index = config.getStack(i);

    if (!config.isToken(index))
      continue;

    auto otherGovIndex = config.getConst(Config::headColName, index, 0);
Franck Dary's avatar
Franck Dary committed
    auto otherGovIndexPredicted = config.getAsFeature(Config::headColName, index);
Franck Dary's avatar
Franck Dary committed
    if (govIndex == std::to_string(index) and util::isEmpty(govIndexPredicted))
      ++nbLinkedWith;
    if (otherGovIndex == std::to_string(withIndex) and util::isEmpty(otherGovIndexPredicted))
Franck Dary's avatar
Franck Dary committed
      ++nbLinkedWith;
  }

  return nbLinkedWith;
}

Franck Dary's avatar
Franck Dary committed
int Transition::getNbLinkedWithHead(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
{
  auto govIndex = config.getConst(Config::headColName, withIndex, 0);
  auto govIndexPredicted = config.getAsFeature(Config::headColName, withIndex);

  int nbLinkedWith = 0;

  for (int i = firstIndex; i <= lastIndex; ++i)
  {
    int index = i;
    if (object == Config::Object::Stack)
      index = config.getStack(i);

    if (!config.isToken(index))
      continue;

    if (govIndex == std::to_string(index) and util::isEmpty(govIndexPredicted))
      ++nbLinkedWith;
  }

  return nbLinkedWith;
}

int Transition::getNbLinkedWithDeps(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
{
  int nbLinkedWith = 0;

  for (int i = firstIndex; i <= lastIndex; ++i)
  {
    int index = i;
    if (object == Config::Object::Stack)
      index = config.getStack(i);

    if (!config.isToken(index))
      continue;

    auto otherGovIndex = config.getConst(Config::headColName, index, 0);
    auto otherGovIndexPredicted = config.getAsFeature(Config::headColName, index);

    if (otherGovIndex == std::to_string(withIndex) and util::isEmpty(otherGovIndexPredicted))
      ++nbLinkedWith;
  }

  return nbLinkedWith;
}

Franck Dary's avatar
Franck Dary committed
int Transition::getFirstIndexOfSentence(int baseIndex, const Config & config)
{
  int firstIndex = baseIndex;

  for (int i = baseIndex; config.has(0, i, 0); --i)
  {
    if (!config.isToken(i))
      continue;

    if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
      break;

    firstIndex = i;
  }

  return firstIndex;
}

/// Last token line of the sentence containing `baseIndex`.
///
/// Scans forward from `baseIndex` over existing lines, skipping non-token
/// lines, and stops at (and includes) the first token whose EOS column holds
/// Config::EOSSymbol1. Without such a token, returns the last token seen, or
/// `baseIndex` when none was seen.
int Transition::getLastIndexOfSentence(int baseIndex, const Config & config)
{
  int last = baseIndex;
  int line = baseIndex;

  while (config.has(0, line, 0))
  {
    if (config.isToken(line))
    {
      last = line;
      if (config.getConst(Config::EOSColName, line, 0) == Config::EOSSymbol1)
        break;
    }
    ++line;
  }

  return last;
}