Skip to content
Snippets Groups Projects
Transition.cpp 17 KiB
Newer Older
  • Learn to ignore specific revisions
  • #include "Transition.hpp"
    
    #include <regex>
    
    // Builds a transition from its textual name.
    //
    // The name may start with an optional "<state> " prefix: the bracketed text
    // is stored in `state` and the remainder in `this->name`. The remainder is
    // then tried against each pattern of the table below, in table order, and
    // the initializer paired with the first pattern that matches is run.
    //
    // Any std::exception raised while doing so (for instance by std::stoi on a
    // non-numeric index) is reported through util::myThrow together with the
    // offending name.
    Transition::Transition(const std::string & name)
    {
      // Pattern -> initializer table. The match object is taken by const
      // reference so that it is not copied on every call.
      std::vector<std::pair<std::regex, std::function<void(const std::smatch &)>>> inits
      {
        {std::regex("WRITE ([bs])\\.(.+) (.+) (.+)"),
          [this](const auto & sm){initWrite(sm[3], sm[1], sm[2], sm[4]);}},
        {std::regex("ADD ([bs])\\.(.+) (.+) (.+)"),
          [this](const auto & sm){initAdd(sm[3], sm[1], sm[2], sm[4]);}},

        {std::regex("eager_SHIFT"),
          [this](const auto &){initEagerShift();}},
        {std::regex("standard_SHIFT"),
          [this](const auto &){initStandardShift();}},

        {std::regex("REDUCE_strict"),
          [this](const auto &){initReduce_strict();}},
        {std::regex("REDUCE_relaxed"),
          [this](const auto &){initReduce_relaxed();}},

        {std::regex("eager_LEFT_rel (.+)"),
          [this](const auto & sm){initEagerLeft_rel(sm[1]);}},
        {std::regex("eager_RIGHT_rel (.+)"),
          [this](const auto & sm){initEagerRight_rel(sm[1]);}},

        {std::regex("standard_LEFT_rel (.+)"),
          [this](const auto & sm){initStandardLeft_rel(sm[1]);}},
        {std::regex("standard_RIGHT_rel (.+)"),
          [this](const auto & sm){initStandardRight_rel(sm[1]);}},

        {std::regex("eager_LEFT"),
          [this](const auto &){initEagerLeft();}},
        {std::regex("eager_RIGHT"),
          [this](const auto &){initEagerRight();}},
        {std::regex("deprel (.+)"),
          [this](const auto & sm){initDeprel(sm[1]);}},

        {std::regex("EOS b\\.(.+)"),
          [this](const auto & sm){initEOS(std::stoi(sm.str(1)));}},

        {std::regex("NOTHING"),
          [this](const auto &){initNothing();}},
        {std::regex("IGNORECHAR"),
          [this](const auto &){initIgnoreChar();}},
        {std::regex("ENDWORD"),
          [this](const auto &){initEndWord();}},
        {std::regex("ADDCHARTOWORD"),
          [this](const auto &){initAddCharToWord();}},

        {std::regex("SPLIT (.+)"),
          [this](const auto & sm){initSplit(std::stoi(sm.str(1)));}},

        // NOTE(review): "(:?" looks like a typo for the non-capturing "(?:".
        // As written both groups are capturing and ":?" is an optional colon;
        // group 2 is still the whole "@a@b..." tail, which is all that is
        // read here, so the pattern is left untouched.
        {std::regex("SPLITWORD ([^@]+)(:?(:?@[^@]+)+)"),
          [this](const auto & sm)
          {
            // First element is the text before the first '@', followed by
            // every piece util::split extracts from the '@'-separated tail.
            std::vector<std::string> splitRes{sm[1]};
            auto splited = util::split(std::string(sm[2]), '@');
            for (auto & s : splited)
              splitRes.emplace_back(s);
            initSplitWord(splitRes);
          }},
      };

      // NOTE(review): the opening of this try block was missing from the
      // source (only the catch clause survived); its position is a
      // reconstruction. Confirm against the upstream file.
      try
      {
        if (!util::doIfNameMatch(std::regex("(<(.+)> )?(.+)"), name, [this](const auto & sm)
              {
                this->state = sm[2];
                this->name = sm[3];
              }))
          util::myThrow("doesn't match nameRegex");

        for (auto & it : inits)
          if (util::doIfNameMatch(it.first, this->name, it.second))
            return;

        // No pattern matched: fail here instead of returning an object whose
        // `sequence` and `cost` were never set.
        util::myThrow("no match");
      }
      catch (const std::exception & e)
      {
        util::myThrow(fmt::format("Invalid name '{}' ({})", this->name, e.what()));
      }
    }
    
    // Applies every action of `sequence` to `config`, in sequence order.
    void Transition::apply(Config & config)
    {
      for (auto current = sequence.begin(); current != sequence.end(); ++current)
        current->apply(config, *current);
    }
    
    bool Transition::appliable(const Config & config) const
    {
    
      if (!state.empty() && state != config.getState())
        return false;
    
    
      for (const Action & action : sequence)
        if (!action.appliable(config, action))
          return false;
    
      return true;
    }
    
    // Evaluates the `cost` callable installed by one of the init* functions on
    // `config` and returns its result.
    int Transition::getCost(const Config & config) const
    {
      const int value = cost(config);

      return value;
    }
    
    
    const std::string & Transition::getName() const
    {
      return name;
    }
    
    
    void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
    {
    
    Franck Dary's avatar
    Franck Dary committed
      auto objectValue = Config::str2object(object);
    
      int indexValue = std::stoi(index);
    
      sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));
    
      cost = [colName, objectValue, indexValue, value](const Config & config)
      {
    
    Franck Dary's avatar
    Franck Dary committed
        int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
    
        if (config.getConst(colName, lineIndex, 0) == value)
          return 0;
    
    void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
    {
    
    Franck Dary's avatar
    Franck Dary committed
      auto objectValue = Config::str2object(object);
    
      int indexValue = std::stoi(index);
    
      sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));
    
      cost = [colName, objectValue, indexValue, value](const Config & config)
      {
    
    Franck Dary's avatar
    Franck Dary committed
        int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
    
    
        auto gold = util::split(config.getConst(colName, lineIndex, 0).get(), '|');
    
        for (auto & part : gold)
          if (part == value)
            return 0;
    
        return 1;
      };
    }
    
    // Configures a transition that queues no action and always costs 0.
    void Transition::initNothing()
    {
      cost = [](const Config &) { return 0; };
    }
    
    
    // Configures a transition that queues Action::ignoreCurrentCharacter and
    // always costs 0.
    void Transition::initIgnoreChar()
    {
      sequence.emplace_back(Action::ignoreCurrentCharacter());

      cost = [](const Config &) { return 0; };
    }
    
    // Configures a transition that queues Action::endWord.
    //
    // The installed cost is 0 when, at the current word index, the FORM
    // returned by config.getConst equals the one returned by
    // config.getAsFeature, and 1 otherwise.
    void Transition::initEndWord()
    {
      sequence.emplace_back(Action::endWord());

      cost = [](const Config & config)
      {
        if (!(config.getConst("FORM", config.getWordIndex(), 0) == config.getAsFeature("FORM", config.getWordIndex())))
          return 1;

        return 0;
      };
    }
    
    void Transition::initAddCharToWord()
    {
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0));
    
      sequence.emplace_back(Action::addLinesIfNeeded(0));
      sequence.emplace_back(Action::addCurCharToCurWord());
      sequence.emplace_back(Action::moveCharacterIndex(1));
    
      cost = [](const Config & config)
      {
        if (!config.hasCharacter(config.getCharacterIndex()))
          return std::numeric_limits<int>::max();
    
        auto letter = fmt::format("{}", config.getLetter(config.getCharacterIndex()));
        auto & goldWord = config.getConst("FORM", config.getWordIndex(), 0).get();
    
        auto & curWord = config.getAsFeature("FORM", config.getWordIndex()).get();
    
        if (curWord.size() + letter.size() > goldWord.size())
          return 1;
    
        for (unsigned int i = 0; i < letter.size(); i++)
          if (goldWord[curWord.size()+i] != letter[i])
            return 1;
    
        return 0;
      };
    }
    
    // Configures a SPLITWORD transition.
    //
    // words : words[0] is the text that is consumed from the input (it is cut
    //         with util::splitAsUtf8 and handed to consumeCharacterIndex); the
    //         whole vector, words[0] included, is written as FORM hypotheses
    //         at buffer positions 0..words.size()-1.
    //         NOTE(review): words[0] is read unconditionally, so the vector
    //         must not be empty; the only caller visible here always supplies
    //         at least one element.
    //
    // Installed cost:
    //   - INT_MAX when the current word is not a multiword, or when
    //     getMultiwordSize(current word) + 2 differs from words.size();
    //   - otherwise the number of positions i for which config has no FORM at
    //     word index + i or that FORM differs from words[i].
    void Transition::initSplitWord(std::vector<std::string> words)
    {
      auto consumedWord = util::splitAsUtf8(words[0]);

      sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0));
      sequence.emplace_back(Action::assertIsEmpty("FORM", Config::Object::Buffer, 0));
      sequence.emplace_back(Action::addLinesIfNeeded(words.size()));
      sequence.emplace_back(Action::consumeCharacterIndex(consumedWord));

      for (unsigned int i = 0; i < words.size(); i++)
        sequence.emplace_back(Action::addHypothesisRelative("FORM", Config::Object::Buffer, i, words[i]));

      sequence.emplace_back(Action::setMultiwordIds(words.size()-1));

      cost = [words](const Config & config)
      {
        if (!config.isMultiword(config.getWordIndex()))
          return std::numeric_limits<int>::max();

        if (config.getMultiwordSize(config.getWordIndex())+2 != static_cast<int>(words.size()))
          return std::numeric_limits<int>::max();

        int mismatches = 0;
        for (unsigned int i = 0; i < words.size(); i++)
          if (!config.has("FORM", config.getWordIndex()+i, 0) or config.getConst("FORM", config.getWordIndex()+i, 0) != words[i])
            mismatches++;

        return mismatches;
      };
    }
    
    
    Franck Dary's avatar
    Franck Dary committed
    // Configures a SPLIT transition that queues Action::split(index).
    //
    // The installed cost is that of entry `index` of
    // config.getAppliableSplitTransitions(), or INT_MAX when `index` does not
    // designate an entry of that container.
    void Transition::initSplit(int index)
    {
      sequence.emplace_back(Action::split(index));

      cost = [index](const Config & config)
      {
        auto & candidates = config.getAppliableSplitTransitions();

        if (index >= 0 and index < static_cast<int>(candidates.size()))
          return candidates[index]->getCost(config);

        return std::numeric_limits<int>::max();
      };
    }
    
    
    void Transition::initEagerShift()
    
      sequence.emplace_back(Action::pushWordIndexOnStack());
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
    
    
      cost = [](const Config & config)
      {
    
        if (config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
          return std::numeric_limits<int>::max();
    
    
        if (!config.isToken(config.getWordIndex()))
    
    Franck Dary's avatar
    Franck Dary committed
        return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);
    
    // Configures the standard SHIFT transition.
    //
    // Queued actions, in order: pushWordIndexOnStack and
    // setRootUpdateIdsEmptyStackIfSentChanged.
    //
    // The installed cost is INT_MAX when the stack is non-empty and its top
    // carries Config::EOSSymbol1 in the EOS column, and 0 otherwise.
    void Transition::initStandardShift()
    {
      sequence.emplace_back(Action::pushWordIndexOnStack());
      sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

      cost = [](const Config & config)
      {
        // With an empty stack there is no top element to inspect.
        if (!config.hasStack(0))
          return 0;

        if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
          return std::numeric_limits<int>::max();

        return 0;
      };
    }
    
    
    void Transition::initEagerLeft_rel(std::string label)
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
      sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::popStack(0));
    
    
      cost = [label](const Config & config)
      {
        auto stackIndex = config.getStack(0);
    
    Franck Dary's avatar
    Franck Dary committed
        auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
    
        auto wordIndex = config.getWordIndex();
        if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
          return std::numeric_limits<int>::max();
    
    
    Franck Dary's avatar
    Franck Dary committed
        int cost = getNbLinkedWith(wordIndex+1, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
    
        if (stackGovIndex != std::to_string(wordIndex))
    
          ++cost;
    
        if (label != config.getConst(Config::deprelColName, stackIndex, 0))
          ++cost;
    
        return cost;
      };
    }
    
    
    // Configures the labelled standard LEFT transition.
    //
    // label : dependency label written on stack position 1.
    //
    // Queued actions, in order: attach(Stack 0, Stack 1),
    // addHypothesisRelative(deprel column, Stack 1, label), popStack(1).
    //
    // Installed cost:
    //   - INT_MAX unless stack positions 1 and 0 both hold tokens;
    //   - otherwise getNbLinkedWith between stack position 1 and the positions
    //     from the current word to the last index of its sentence, plus 1 when
    //     the head column of stack position 1 is not the index at stack
    //     position 0, plus 1 when its deprel column is not `label`.
    void Transition::initStandardLeft_rel(std::string label)
    {
      sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Stack, 1));
      sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 1, label));
      sequence.emplace_back(Action::popStack(1));

      cost = [label](const Config & config)
      {
        auto second = config.getStack(1);
        auto secondHead = config.getConst(Config::headColName, second, 0);
        auto top = config.getStack(0);

        if (!config.isToken(second) || !config.isToken(top))
          return std::numeric_limits<int>::max();

        int total = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, second, config);

        if (secondHead != std::to_string(top))
          total += 1;

        if (label != config.getConst(Config::deprelColName, second, 0))
          total += 1;

        return total;
      };
    }
    
    
    // Configures the unlabelled eager LEFT transition.
    //
    // Queued actions, in order: attach(Buffer 0, Stack 0), popStack(0).
    //
    // Installed cost:
    //   - INT_MAX unless both the stack top and the current word are tokens;
    //   - otherwise getNbLinkedWith between the stack top and the positions
    //     from the word after the current one to the last index of its
    //     sentence, plus 1 when the head column of the stack top is not the
    //     current word index.
    void Transition::initEagerLeft()
    {
      sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
      sequence.emplace_back(Action::popStack(0));

      cost = [](const Config & config)
      {
        auto stackIndex = config.getStack(0);
        auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
        auto wordIndex = config.getWordIndex();

        if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
          return std::numeric_limits<int>::max();

        int total = getNbLinkedWith(wordIndex+1, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);

        if (stackGovIndex != std::to_string(wordIndex))
          ++total;

        return total;
      };
    }
    
    void Transition::initEagerRight_rel(std::string label)
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
      sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
    
      sequence.emplace_back(Action::pushWordIndexOnStack());
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
    
    
      cost = [label](const Config & config)
      {
        auto stackIndex = config.getStack(0);
        auto wordIndex = config.getWordIndex();
        if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
          return std::numeric_limits<int>::max();
    
    
        auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);
    
    Franck Dary's avatar
    Franck Dary committed
        if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
          return std::numeric_limits<int>::max();
    
    Franck Dary's avatar
    Franck Dary committed
        int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);
    
        if (bufferGovIndex != std::to_string(stackIndex))
    
          ++cost;
    
        if (label != config.getConst(Config::deprelColName, wordIndex, 0))
          ++cost;
    
        return cost;
      };
    }
    
    
    // Configures the labelled standard RIGHT transition.
    //
    // label : dependency label written on stack position 0.
    //
    // Queued actions, in order: attach(Stack 1, Stack 0),
    // addHypothesisRelative(deprel column, Stack 0, label), popStack(0).
    //
    // Installed cost:
    //   - INT_MAX unless stack positions 1 and 0 both hold tokens;
    //   - otherwise getNbLinkedWith between stack position 0 and the positions
    //     from the current word to the last index of its sentence, plus 1 when
    //     the head column of stack position 0 is not the index at stack
    //     position 1, plus 1 when its deprel column is not `label`.
    void Transition::initStandardRight_rel(std::string label)
    {
      sequence.emplace_back(Action::attach(Config::Object::Stack, 1, Config::Object::Stack, 0));
      sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
      sequence.emplace_back(Action::popStack(0));

      cost = [label](const Config & config)
      {
        auto second = config.getStack(1);
        auto top = config.getStack(0);

        if (!config.isToken(second) || !config.isToken(top))
          return std::numeric_limits<int>::max();

        auto topHead = config.getConst(Config::headColName, top, 0);

        int total = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, top, config);

        if (topHead != std::to_string(second))
          total += 1;

        if (label != config.getConst(Config::deprelColName, top, 0))
          total += 1;

        return total;
      };
    }
    
    
    // Configures the unlabelled eager RIGHT transition.
    //
    // Queued actions, in order: attach(Stack 0, Buffer 0),
    // pushWordIndexOnStack, setRootUpdateIdsEmptyStackIfSentChanged.
    //
    // Installed cost:
    //   - INT_MAX unless both the stack top and the current word are tokens;
    //   - INT_MAX when the stack top carries Config::EOSSymbol1 in the EOS
    //     column;
    //   - otherwise getNbLinkedWith between the current word and stack
    //     positions 1..getStackSize()-1, plus 1 when the head column of the
    //     current word is not the stack top index.
    void Transition::initEagerRight()
    {
      sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
      sequence.emplace_back(Action::pushWordIndexOnStack());
      sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());

      cost = [](const Config & config)
      {
        auto stackIndex = config.getStack(0);
        auto wordIndex = config.getWordIndex();

        if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
          return std::numeric_limits<int>::max();

        auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);

        if (config.getConst(Config::EOSColName, stackIndex, 0) == Config::EOSSymbol1)
          return std::numeric_limits<int>::max();

        int total = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, wordIndex, config);

        if (bufferGovIndex != std::to_string(stackIndex))
          ++total;

        return total;
      };
    }
    
    
    void Transition::initReduce_strict()
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::popStack(0));
    
      cost = [](const Config & config)
      {
        auto stackIndex = config.getStack(0);
    
    Franck Dary's avatar
    Franck Dary committed
        auto wordIndex = config.getWordIndex();
    
    Franck Dary's avatar
    Franck Dary committed
        if (!config.isToken(stackIndex))
          return 0;
    
    Franck Dary's avatar
    Franck Dary committed
        int cost = getNbLinkedWith(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
    
    Franck Dary's avatar
    Franck Dary committed
        if (config.getConst(Config::EOSColName, stackIndex, 0) == Config::EOSSymbol1)
    
          ++cost;
    
        return cost;
      };
    }
    
    void Transition::initReduce_relaxed()
    {
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::popStack(0));
    
      cost = [](const Config & config)
      {
    
        auto stackIndex = config.getStack(0);
    
    Franck Dary's avatar
    Franck Dary committed
        auto wordIndex = config.getWordIndex();
    
    Franck Dary's avatar
    Franck Dary committed
        if (!config.isToken(stackIndex))
          return 0;
    
    Franck Dary's avatar
    Franck Dary committed
        int cost = getNbLinkedWith(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
    
    Franck Dary's avatar
    Franck Dary committed
        if (config.getConst(Config::EOSColName, stackIndex, 0) == Config::EOSSymbol1)
    
          ++cost;
    
        return cost;
      };
    
    Franck Dary's avatar
    Franck Dary committed
    void Transition::initEOS(int bufferIndex)
    
    Franck Dary's avatar
    Franck Dary committed
    {
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::setRoot(bufferIndex));
      sequence.emplace_back(Action::updateIds(bufferIndex));
      sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1));
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::emptyStack());
    
    Franck Dary's avatar
    Franck Dary committed
    
    
    Franck Dary's avatar
    Franck Dary committed
      cost = [bufferIndex](const Config & config)
    
    Franck Dary's avatar
    Franck Dary committed
      {
    
    Franck Dary's avatar
    Franck Dary committed
        int lineIndex = config.getRelativeWordIndex(Config::Buffer, bufferIndex);
        if (config.getConst(Config::EOSColName, lineIndex, 0) != Config::EOSSymbol1)
    
          return std::numeric_limits<int>::max();
    
    Franck Dary's avatar
    Franck Dary committed
    
    
    Franck Dary's avatar
    Franck Dary committed
        return 0;
    
    // Configures a transition that queues Action::deprel(label).
    //
    // The installed cost is 0 when the deprel column returned by
    // config.getConst for config.getLastAttached() equals `label`, and 1
    // otherwise.
    void Transition::initDeprel(std::string label)
    {
      sequence.emplace_back(Action::deprel(label));

      cost = [label](const Config & config)
      {
        if (config.getConst(Config::deprelColName, config.getLastAttached(), 0) == label)
          return 0;

        return 1;
      };
    }
    
    
    Franck Dary's avatar
    Franck Dary committed
    // Counts head-column links between `withIndex` and a range of positions.
    //
    // firstIndex, lastIndex : inclusive range of positions to scan.
    // object    : when Config::Object::Stack, each position is a stack
    //             position and is translated with config.getStack; otherwise
    //             the position is used directly as a line index.
    // withIndex : line index the other positions are compared against.
    //
    // For every scanned line that is a token, the count is increased:
    //   - once if the head column of `withIndex` (config.getConst) names that
    //     line while the head column of `withIndex` read through
    //     config.getAsFeature is empty;
    //   - once if the head column of that line (config.getConst) names
    //     `withIndex` while that line's head read through config.getAsFeature
    //     is empty.
    int Transition::getNbLinkedWith(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
    {
      auto govIndex = config.getConst(Config::headColName, withIndex, 0);
      auto govIndexPredicted = config.getAsFeature(Config::headColName, withIndex);

      // Loop-invariant, so it is formatted once instead of on every iteration.
      const std::string withIndexStr = std::to_string(withIndex);

      int nbLinkedWith = 0;

      for (int i = firstIndex; i <= lastIndex; ++i)
      {
        int index = i;
        if (object == Config::Object::Stack)
          index = config.getStack(i);

        if (!config.isToken(index))
          continue;

        auto otherGovIndex = config.getConst(Config::headColName, index, 0);
        auto otherGovIndexPredicted = config.getAsFeature(Config::headColName, index);

        if (govIndex == std::to_string(index) and util::isEmpty(govIndexPredicted))
          ++nbLinkedWith;
        if (otherGovIndex == withIndexStr and util::isEmpty(otherGovIndexPredicted))
          ++nbLinkedWith;
      }

      return nbLinkedWith;
    }
    
    // Walks backwards from `baseIndex` while config.has(0, i, 0) holds and
    // returns the lowest token index seen before reaching a token that
    // carries Config::EOSSymbol1 in the EOS column. Non-token lines are
    // skipped. Returns `baseIndex` itself when no earlier token qualifies.
    int Transition::getFirstIndexOfSentence(int baseIndex, const Config & config)
    {
      int result = baseIndex;
      int cursor = baseIndex;

      while (config.has(0, cursor, 0))
      {
        if (config.isToken(cursor))
        {
          // An end-of-sentence mark stops the walk; that token is not kept.
          if (config.getConst(Config::EOSColName, cursor, 0) == Config::EOSSymbol1)
            break;

          result = cursor;
        }

        --cursor;
      }

      return result;
    }
    
    // Walks forwards from `baseIndex` while config.has(0, i, 0) holds and
    // returns the highest token index seen, stopping at (and including) the
    // first token that carries Config::EOSSymbol1 in the EOS column.
    // Non-token lines are skipped. Returns `baseIndex` when no token is seen.
    int Transition::getLastIndexOfSentence(int baseIndex, const Config & config)
    {
      int result = baseIndex;
      int cursor = baseIndex;

      while (config.has(0, cursor, 0))
      {
        if (config.isToken(cursor))
        {
          // The token is recorded before the mark is tested, so the token
          // carrying the mark is the one returned.
          result = cursor;

          if (config.getConst(Config::EOSColName, cursor, 0) == Config::EOSSymbol1)
            break;
        }

        ++cursor;
      }

      return result;
    }