// Transition.cpp: construction of Transition objects from their textual names,
// together with the action sequence and the cost function of each transition kind.
    #include "Transition.hpp"

    #include <limits>
    #include <regex>
    #include <stdexcept>
    
    // Builds a transition from its textual name.
    //
    // The name may start with an optional "<state> " prefix: the text between the angle
    // brackets is stored in `state` (left empty when there is no prefix) and the rest in
    // `name`. That rest is matched against each pattern of `inits` in order; the first
    // pattern that matches calls the corresponding init* method, which fills `sequence`
    // and `cost`.
    //
    // name: full transition name, e.g. "WRITE b.0 UPOS NOUN" or "<tagger> SHIFT".
    // Errors: reported through util::myThrow with "Invalid name '<name>' (<reason>)" when the
    // name is malformed, matches no pattern, or an init* method throws (e.g. std::stoi on a
    // non-numeric index).
    Transition::Transition(const std::string & name)
    {
      try
      {
        // Pattern -> initializer. The lambdas take the match by const reference, as the
        // std::function signature does, instead of copying the whole std::smatch.
        // Order matters: the first matching pattern wins.
        std::vector<std::pair<std::regex, std::function<void(const std::smatch &)>>> inits
        {
          {std::regex("WRITE ([bs])\\.(.+) (.+) (.+)"),
            [this](const auto & sm){initWrite(sm[3], sm[1], sm[2], sm[4]);}},
          {std::regex("ADD ([bs])\\.(.+) (.+) (.+)"),
            [this](const auto & sm){initAdd(sm[3], sm[1], sm[2], sm[4]);}},
          {std::regex("SHIFT"),
            [this](const auto &){initShift();}},

          {std::regex("REDUCE_strict"),
            [this](const auto &){initReduce_strict();}},
          {std::regex("REDUCE_relaxed"),
            [this](const auto &){initReduce_relaxed();}},

          {std::regex("eager_LEFT_rel (.+)"),
            [this](const auto & sm){initEagerLeft_rel(sm[1]);}},
          {std::regex("eager_RIGHT_rel (.+)"),
            [this](const auto & sm){initEagerRight_rel(sm[1]);}},
          {std::regex("eager_LEFT"),
            [this](const auto &){initEagerLeft();}},
          {std::regex("eager_RIGHT"),
            [this](const auto &){initEagerRight();}},
          {std::regex("deprel (.+)"),
            [this](const auto & sm){initDeprel(sm[1]);}},

          {std::regex("EOS b\\.(.+)"),
            [this](const auto & sm){initEOS(std::stoi(sm.str(1)));}},

          {std::regex("NOTHING"),
            [this](const auto &){initNothing();}},
          {std::regex("IGNORECHAR"),
            [this](const auto &){initIgnoreChar();}},
          {std::regex("ENDWORD"),
            [this](const auto &){initEndWord();}},
          {std::regex("ADDCHARTOWORD"),
            [this](const auto &){initAddCharToWord();}},

          {std::regex("SPLIT (.+)"),
            [this](const auto & sm){initSplit(std::stoi(sm.str(1)));}},

          // "SPLITWORD consumed@w1@w2...": group 1 is the consumed text, group 2 the "@w1@w2..."
          // tail. The inner group is non-capturing "(?:"; the former "(:?" was a typo (a
          // capturing group with an optional ':'). Because "[^@]+" already absorbs any ':',
          // the accepted names and the contents of groups 1 and 2 are unchanged.
          {std::regex("SPLITWORD ([^@]+)((?:@[^@]+)+)"),
            [this](const auto & sm)
            {
              std::vector<std::string> splitRes{sm.str(1)};
              for (auto & part : util::split(sm.str(2), '@'))
                splitRes.emplace_back(part);
              initSplitWord(splitRes);
            }},
        };

        // Separate the optional "<state> " prefix from the transition name proper.
        if (!util::doIfNameMatch(std::regex("(<(.+)> )?(.+)"), name, [this](const auto & sm)
              {
                this->state = sm[2];
                this->name = sm[3];
              }))
          util::myThrow("doesn't match nameRegex");

        for (auto & it : inits)
          if (util::doIfNameMatch(it.first, this->name, it.second))
            return;

        // No pattern accepted the name. Without this throw, the transition would silently keep
        // an empty sequence and an empty `cost` function.
        throw std::invalid_argument("no transition pattern matches");
      } catch (const std::exception & e) {util::myThrow(fmt::format("Invalid name '{}' ({})", name, e.what()));}
    }
    
    // Runs every action of this transition on `config`, in sequence order. Each action is
    // passed itself as the second argument of Action::apply.
    void Transition::apply(Config & config)
    {
      for (auto it = sequence.begin(), end = sequence.end(); it != end; ++it)
        it->apply(config, *it);
    }
    
    bool Transition::appliable(const Config & config) const
    {
    
      if (!state.empty() && state != config.getState())
        return false;
    
    
      for (const Action & action : sequence)
        if (!action.appliable(config, action))
          return false;
    
      return true;
    }
    
    int Transition::getCost(const Config & config) const
    {
      return cost(config);
    }
    
    
    const std::string & Transition::getName() const
    {
      return name;
    }
    
    
    void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
    {
    
    Franck Dary's avatar
    Franck Dary committed
      auto objectValue = Config::str2object(object);
    
      int indexValue = std::stoi(index);
    
      sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));
    
      cost = [colName, objectValue, indexValue, value](const Config & config)
      {
    
    Franck Dary's avatar
    Franck Dary committed
        int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
    
        if (config.getConst(colName, lineIndex, 0) == value)
          return 0;
    
    void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
    {
    
    Franck Dary's avatar
    Franck Dary committed
      auto objectValue = Config::str2object(object);
    
      int indexValue = std::stoi(index);
    
      sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));
    
      cost = [colName, objectValue, indexValue, value](const Config & config)
      {
    
    Franck Dary's avatar
    Franck Dary committed
        int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
    
    
        auto gold = util::split(config.getConst(colName, lineIndex, 0).get(), '|');
    
        for (auto & part : gold)
          if (part == value)
            return 0;
    
        return 1;
      };
    }
    
    // NOTHING: no action is added, and applying the transition always costs 0.
    void Transition::initNothing()
    {
      cost = [](const Config &) { return 0; };
    }
    
    
    // IGNORECHAR: one Action::ignoreCurrentCharacter, at a constant cost of 0.
    void Transition::initIgnoreChar()
    {
      sequence.emplace_back(Action::ignoreCurrentCharacter());
      cost = [](const Config &) { return 0; };
    }
    
    // ENDWORD: one Action::endWord.
    // Cost: 0 when the reference FORM (Config::getConst) of the current word equals its
    // FORM as returned by Config::getAsFeature, 1 otherwise.
    void Transition::initEndWord()
    {
      sequence.emplace_back(Action::endWord());

      cost = [](const Config & config)
      {
        const auto current = config.getWordIndex();
        const bool formComplete = config.getConst("FORM", current, 0) == config.getAsFeature("FORM", current);
        return formComplete ? 0 : 1;
      };
    }
    
    void Transition::initAddCharToWord()
    {
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0));
    
      sequence.emplace_back(Action::addLinesIfNeeded(0));
      sequence.emplace_back(Action::addCurCharToCurWord());
      sequence.emplace_back(Action::moveCharacterIndex(1));
    
      cost = [](const Config & config)
      {
        if (!config.hasCharacter(config.getCharacterIndex()))
          return std::numeric_limits<int>::max();
    
        auto letter = fmt::format("{}", config.getLetter(config.getCharacterIndex()));
        auto & goldWord = config.getConst("FORM", config.getWordIndex(), 0).get();
    
        auto & curWord = config.getAsFeature("FORM", config.getWordIndex()).get();
    
        if (curWord.size() + letter.size() > goldWord.size())
          return 1;
    
        for (unsigned int i = 0; i < letter.size(); i++)
          if (goldWord[curWord.size()+i] != letter[i])
            return 1;
    
        return 0;
      };
    }
    
    // SPLITWORD: `words[0]` is the text consumed from the input. The actions, in order:
    //   - assert that the id and FORM columns of buffer word 0 are empty;
    //   - add lines if needed and consume the characters of `words[0]`;
    //   - write words[i] as the FORM of buffer word i, for every i;
    //   - set multiword ids with size words.size()-1.
    //
    // Cost:
    //   - int max unless the current word is a multiword whose getMultiwordSize()+2 equals
    //     words.size();
    //   - otherwise the number of positions i whose reference FORM (Config::getConst) at
    //     wordIndex+i is missing or differs from words[i].
    void Transition::initSplitWord(std::vector<std::string> words)
    {
      auto consumedWord = util::splitAsUtf8(words[0]);

      sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0));
      sequence.emplace_back(Action::assertIsEmpty("FORM", Config::Object::Buffer, 0));
      sequence.emplace_back(Action::addLinesIfNeeded(words.size()));
      sequence.emplace_back(Action::consumeCharacterIndex(consumedWord));
      for (unsigned int i = 0; i < words.size(); i++)
        sequence.emplace_back(Action::addHypothesisRelative("FORM", Config::Object::Buffer, i, words[i]));
      sequence.emplace_back(Action::setMultiwordIds(words.size()-1));

      cost = [words](const Config & config)
      {
        if (!config.isMultiword(config.getWordIndex()))
          return std::numeric_limits<int>::max();

        if (config.getMultiwordSize(config.getWordIndex())+2 != static_cast<int>(words.size()))
          return std::numeric_limits<int>::max();

        int cost = 0;
        for (unsigned int i = 0; i < words.size(); i++)
          if (!config.has("FORM", config.getWordIndex()+i, 0) or config.getConst("FORM", config.getWordIndex()+i, 0) != words[i])
            cost++;

        return cost;
      };
    }
    
    
    Franck Dary's avatar
    Franck Dary committed
    void Transition::initSplit(int index)
    {
      sequence.emplace_back(Action::split(index));
    
      cost = [index](const Config & config)
      {
        auto & transitions = config.getAppliableSplitTransitions();
    
        if (index < 0 or index >= (int)transitions.size())
          return std::numeric_limits<int>::max();
    
        return transitions[index]->getCost(config);
      };
    }
    
    
    void Transition::initShift()
    
      sequence.emplace_back(Action::pushWordIndexOnStack());
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
    
    
      cost = [](const Config & config)
      {
    
        if (config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
          return std::numeric_limits<int>::max();
    
    
        if (!config.isToken(config.getWordIndex()))
    
        auto headGovIndex = config.getConst(Config::headColName, config.getWordIndex(), 0);
    
    
        int cost = 0;
        for (int i = 0; config.hasStack(i); ++i)
        {
    
          if (!config.has(0, config.getStack(i), 0))
            continue;
    
    
          auto stackIndex = config.getStack(i);
    
          auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
    
          if (stackGovIndex == std::to_string(config.getWordIndex()) || headGovIndex == std::to_string(stackIndex))
    
    void Transition::initEagerLeft_rel(std::string label)
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
      sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
    
      sequence.emplace_back(Action::popStack());
    
      cost = [label](const Config & config)
      {
        auto stackIndex = config.getStack(0);
        auto wordIndex = config.getWordIndex();
        if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
          return std::numeric_limits<int>::max();
    
        int cost = 0;
    
    
        auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
    
    
        for (int i = wordIndex+1; config.has(0, i, 0); ++i)
        {
          if (!config.isToken(i))
            continue;
    
          if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
            break;
    
    
          auto otherGovIndex = config.getConst(Config::headColName, i, 0);
    
          if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
    
            ++cost;
        }
    
        //TODO : Check if this is necessary
    
        if (stackGovIndex != std::to_string(wordIndex))
    
          ++cost;
    
        if (label != config.getConst(Config::deprelColName, stackIndex, 0))
          ++cost;
    
        return cost;
      };
    }
    
    
    // eager_LEFT: Action::attach(Buffer 0, Stack 0), then pops the stack.
    //
    // Cost, against reference values (Config::getConst):
    //   - int max unless both the stack top and the current word are tokens;
    //   - +1 for each token after the current word that is the stack top's head or one of
    //     its dependents, stopping before the first token marked Config::EOSSymbol1;
    //   - +1 when the stack top's head is not the current word.
    void Transition::initEagerLeft()
    {
      sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
      sequence.emplace_back(Action::popStack());

      cost = [](const Config & config)
      {
        auto dependent = config.getStack(0);
        auto head = config.getWordIndex();
        if (!config.isToken(dependent) || !config.isToken(head))
          return std::numeric_limits<int>::max();

        auto dependentGov = config.getConst(Config::headColName, dependent, 0);
        const auto dependentStr = std::to_string(dependent);

        int penalty = 0;
        int line = head + 1;
        while (config.has(0, line, 0))
        {
          if (config.isToken(line))
          {
            if (config.getConst(Config::EOSColName, line, 0) == Config::EOSSymbol1)
              break;

            auto lineGov = config.getConst(Config::headColName, line, 0);
            if (dependentGov == std::to_string(line) || lineGov == dependentStr)
              ++penalty;
          }
          ++line;
        }

        //TODO : Check if this is necessary
        if (dependentGov != std::to_string(head))
          ++penalty;

        return penalty;
      };
    }
    
    void Transition::initEagerRight_rel(std::string label)
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
      sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
    
      sequence.emplace_back(Action::pushWordIndexOnStack());
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
    
    
      cost = [label](const Config & config)
      {
        auto stackIndex = config.getStack(0);
        auto wordIndex = config.getWordIndex();
        if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
          return std::numeric_limits<int>::max();
    
        int cost = 0;
    
    
        auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);
    
    
        for (int i = wordIndex; config.has(0, i, 0); ++i)
        {
          if (!config.isToken(i))
            continue;
    
    
          auto otherGovIndex = config.getConst(Config::headColName, i, 0);
    
    Franck Dary's avatar
    Franck Dary committed
          if (bufferGovIndex == std::to_string(i))
    
            ++cost;
    
          if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
            break;
        }
    
        for (int i = 1; config.hasStack(i); ++i)
        {
    
          if (!config.has(0, config.getStack(i), 0))
            continue;
    
    
          auto otherStackIndex = config.getStack(i);
    
          auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0);
    
          if (otherStackGov == std::to_string(wordIndex) || bufferGovIndex == std::to_string(otherStackIndex))
    
            ++cost;
        }
    
        //TODO : Check if this is necessary
    
        if (bufferGovIndex != std::to_string(stackIndex))
    
          ++cost;
    
        if (label != config.getConst(Config::deprelColName, wordIndex, 0))
          ++cost;
    
        return cost;
      };
    }
    
    
    // eager_RIGHT: performs these actions in order:
    //   - Action::attach(Stack 0, Buffer 0);
    //   - pushes the current word index onto the stack;
    //   - runs Action::setRootUpdateIdsEmptyStackIfSentChanged.
    //
    // Cost, against reference values (Config::getConst):
    //   - int max unless both the stack top and the current word are tokens;
    //   - +1 for each token from the current word onward that is the current word's head,
    //     up to and including the first token marked Config::EOSSymbol1;
    //   - +1 for each stack element below the top that is the current word's head or one
    //     of its dependents;
    //   - +1 when the current word's head is not the stack top.
    void Transition::initEagerRight()
    {
      sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
      sequence.emplace_back(Action::pushWordIndexOnStack());
      sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

      cost = [](const Config & config)
      {
        auto stackIndex = config.getStack(0);
        auto wordIndex = config.getWordIndex();
        if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
          return std::numeric_limits<int>::max();

        int cost = 0;

        auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);

        // (The unused per-line head lookup `otherGovIndex` that used to sit here was removed.)
        for (int i = wordIndex; config.has(0, i, 0); ++i)
        {
          if (!config.isToken(i))
            continue;

          if (bufferGovIndex == std::to_string(i))
            ++cost;

          if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
            break;
        }

        for (int i = 1; config.hasStack(i); ++i)
        {
          if (!config.has(0, config.getStack(i), 0))
            continue;

          auto otherStackIndex = config.getStack(i);
          auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0);

          if (otherStackGov == std::to_string(wordIndex) || bufferGovIndex == std::to_string(otherStackIndex))
            ++cost;
        }

        //TODO : Check if this is necessary
        if (bufferGovIndex != std::to_string(stackIndex))
          ++cost;

        return cost;
      };
    }
    
    
    void Transition::initReduce_strict()
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
    
      sequence.emplace_back(Action::popStack());
    
    
      cost = [](const Config & config)
      {
        if (!config.isToken(config.getStack(0)))
          return 0;
    
        int cost = 0;
    
        auto stackIndex = config.getStack(0);
        auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
    
        for (int i = config.getWordIndex(); config.has(0, i, 0); ++i)
        {
          if (!config.isToken(i))
            continue;
    
          auto otherGovIndex = config.getConst(Config::headColName, i, 0);
    
          if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
            ++cost;
    
          if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
            break;
        }
    
        if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
          ++cost;
    
        return cost;
      };
    }
    
    // REDUCE_relaxed: pops the stack. Unlike REDUCE_strict, it does not assert that the
    // stack top already has a head.
    //
    // Cost, against reference values (Config::getConst):
    //   - 0 when the stack top is not a token;
    //   - +1 for each token from the current word onward that is the stack top's head or one
    //     of its dependents, up to and including the first token marked Config::EOSSymbol1;
    //   - +1 when the stack top itself is marked Config::EOSSymbol1.
    void Transition::initReduce_relaxed()
    {
      sequence.emplace_back(Action::popStack());

      cost = [](const Config & config)
      {
        // The guard before `return 0` was missing from the scraped source. Without it, every
        // line below was unreachable; it is restored to match initReduce_strict.
        if (!config.isToken(config.getStack(0)))
          return 0;

        int cost = 0;

        auto stackIndex = config.getStack(0);
        auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);

        for (int i = config.getWordIndex(); config.has(0, i, 0); ++i)
        {
          if (!config.isToken(i))
            continue;

          auto otherGovIndex = config.getConst(Config::headColName, i, 0);

          if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
            ++cost;

          if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
            break;
        }

        if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
          ++cost;

        return cost;
      };
    }
    
    Franck Dary's avatar
    Franck Dary committed
    void Transition::initEOS(int bufferIndex)
    
    Franck Dary's avatar
    Franck Dary committed
    {
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::setRoot(bufferIndex));
      sequence.emplace_back(Action::updateIds(bufferIndex));
      sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1));
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::emptyStack());
    
    Franck Dary's avatar
    Franck Dary committed
    
    
    Franck Dary's avatar
    Franck Dary committed
      cost = [bufferIndex](const Config & config)
    
    Franck Dary's avatar
    Franck Dary committed
      {
    
    Franck Dary's avatar
    Franck Dary committed
        int lineIndex = config.getRelativeWordIndex(Config::Buffer, bufferIndex);
        if (config.getConst(Config::EOSColName, lineIndex, 0) != Config::EOSSymbol1)
    
          return std::numeric_limits<int>::max();
    
    Franck Dary's avatar
    Franck Dary committed
    
    
    Franck Dary's avatar
    Franck Dary committed
        return 0;
    
    // "deprel <label>": one Action::deprel(label).
    // Cost: 0 when the reference deprel (Config::getConst) of config.getLastAttached()
    // equals `label`, 1 otherwise.
    void Transition::initDeprel(std::string label)
    {
      sequence.emplace_back(Action::deprel(label));

      cost = [label](const Config & config)
      {
        const auto attached = config.getLastAttached();
        if (config.getConst(Config::deprelColName, attached, 0) == label)
          return 0;
        return 1;
      };
    }