    // Transition.cpp
    #include "Transition.hpp"
    
    #include <regex>
    
    // Build a transition from its textual name. Recognised forms:
    //   "SHIFT", "REDUCE", "EOS", "LEFT <label>", "RIGHT <label>",
    //   "WRITE {b|s}.<index> <column> <value>".
    //
    // Each pattern is tried in turn through util::doIfNameMatch; the first one that
    // matches calls the corresponding init* method, which fills `sequence` and installs
    // `cost`.
    //
    // Any exception raised while parsing (no pattern matched, or an init* method threw,
    // e.g. std::stoi on a malformed WRITE index) is reported through util::myThrow with a
    // message naming the offending transition.
    Transition::Transition(const std::string & name)
    {
      this->name = name;

      try
      {
        std::regex writeRegex("WRITE ([bs])\\.(.+) (.+) (.+)");
        std::regex shiftRegex("SHIFT");
        std::regex reduceRegex("REDUCE");
        std::regex leftRegex("LEFT (.+)");
        std::regex rightRegex("RIGHT (.+)");
        std::regex eosRegex("EOS");

        // WRITE groups: [1] object letter (b|s, decoded by Action::str2object),
        // [2] index, [3] column name, [4] value.
        if (util::doIfNameMatch(writeRegex, name, [this](const auto & sm){initWrite(sm[3], sm[1], sm[2], sm[4]);}))
          return;
        if (util::doIfNameMatch(shiftRegex, name, [this](const auto &){initShift();}))
          return;
        if (util::doIfNameMatch(reduceRegex, name, [this](const auto &){initReduce();}))
          return;
        if (util::doIfNameMatch(leftRegex, name, [this](const auto & sm){initLeft(sm[1]);}))
          return;
        if (util::doIfNameMatch(rightRegex, name, [this](const auto & sm){initRight(sm[1]);}))
          return;
        if (util::doIfNameMatch(eosRegex, name, [this](const auto &){initEOS();}))
          return;

        // Reported through the catch below together with the offending name.
        throw std::invalid_argument("no match");
      }
      catch (const std::exception & e)
      {
        util::myThrow(fmt::format("Invalid name '{}' ({})", name, e.what()));
      }
    }
    
    
    // Execute the transition: every action of `sequence` is applied to `config` in
    // insertion order, each action also being handed itself as second argument.
    void Transition::apply(Config & config)
    {
      for (auto it = sequence.begin(), last = sequence.end(); it != last; ++it)
      {
        Action & current = *it;
        current.apply(config, current);
      }
    }
    
    bool Transition::appliable(const Config & config) const
    {
      for (const Action & action : sequence)
        if (!action.appliable(config, action))
          return false;
    
      return true;
    }
    
    // Cost of this transition in `config`, as computed by the function the matching
    // init* method stored in `cost`.
    int Transition::getCost(const Config & config) const
    {
      const auto & costFunction = cost;
      return costFunction(config);
    }
    
    
    const std::string & Transition::getName() const
    {
      return name;
    }
    
    
    // Configure a WRITE transition from the pieces parsed by the constructor.
    //
    // colName: column to write into; object: object letter, decoded by Action::str2object;
    // index: position relative to that object, parsed with std::stoi (which throws
    // std::invalid_argument / std::out_of_range on malformed input); value: value to write.
    //
    // Appends one Action::addHypothesisRelative action and installs a cost that locates the
    // targeted line and returns 0 when config.getConst(colName, line, 0) already equals
    // `value`.
    void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
    {
      auto objectValue = Action::str2object(object);
      int indexValue = std::stoi(index);
    
      sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));
    
      cost = [colName, objectValue, indexValue, value](const Config & config)
      {
        int lineIndex = 0;
        // Buffer targets are offsets from the current word index; otherwise `indexValue`
        // selects a stack position.
        if (objectValue == Action::Object::Buffer)
          lineIndex = config.getWordIndex() + indexValue;
        else
          lineIndex = config.getStack(indexValue);
    
        if (config.getConst(colName, lineIndex, 0) == value)
          return 0;
    
        // NOTE(review): the rest of this lambda (the return for a non-matching value) and the
        // closing braces of the lambda and of initWrite are missing from this copy of the
        // file, so it does not compile as written. Restore them from version control.
    void Transition::initShift()
    
      sequence.emplace_back(Action::pushWordIndexOnStack());
    
      cost = [](const Config & config)
      {
    
        if (!config.isToken(config.getWordIndex()))
    
          return 0;
    
        auto headGov = config.getConst(Config::headColName, config.getWordIndex(), 0);
        auto headId = config.getConst(Config::idColName, config.getWordIndex(), 0);
    
        int cost = 0;
        for (int i = 0; config.hasStack(i); ++i)
        {
    
          if (!config.has(0, config.getStack(i), 0))
            continue;
    
    
          auto stackIndex = config.getStack(i);
          auto stackId = config.getConst(Config::idColName, stackIndex, 0);
          auto stackGov = config.getConst(Config::headColName, stackIndex, 0);
    
          if (stackGov == headId || headGov == stackId)
            ++cost;
        }
    
        return cost;
      };
    }
    
    // Configure a LEFT transition labelled `label`. The actions, in order, are:
    //   1. Action::attach(Buffer 0, Stack 0);
    //   2. write `label` into the deprel column of the stack top;
    //   3. pop the stack.
    //
    // Cost: std::numeric_limits<int>::max() when either the stack top or the word at the
    // word index is not a token. Otherwise the total of:
    //   - each token after the word index, up to (but excluding) the first token marked
    //     Config::EOSSymbol1, that is linked to the stack top by a reference governor
    //     relation;
    //   - 1 if the stack top's reference governor is not the word at the word index;
    //   - 1 if `label` differs from the stack top's reference deprel.
    void Transition::initLeft(std::string label)
    {
      sequence.emplace_back(Action::attach(Action::Object::Buffer, 0, Action::Object::Stack, 0));
      sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Stack, 0, label));
      sequence.emplace_back(Action::popStack());

      cost = [label](const Config & config)
      {
        const auto top = config.getStack(0);
        const auto front = config.getWordIndex();
        if (!config.isToken(top) || !config.isToken(front))
          return std::numeric_limits<int>::max();

        const auto topId = config.getConst(Config::idColName, top, 0);
        const auto topGoldGov = config.getConst(Config::headColName, top, 0);

        int penalty = 0;

        int pos = front + 1;
        while (config.has(0, pos, 0))
        {
          if (config.isToken(pos))
          {
            if (config.getConst(Config::EOSColName, pos, 0) == Config::EOSSymbol1)
              break;

            const auto posId = config.getConst(Config::idColName, pos, 0);
            const auto posGoldGov = config.getConst(Config::headColName, pos, 0);
            if (topGoldGov == posId || posGoldGov == topId)
              ++penalty;
          }
          ++pos;
        }

        // TODO: check whether this governor check is necessary.
        if (topGoldGov != config.getConst(Config::idColName, front, 0))
          ++penalty;

        if (label != config.getConst(Config::deprelColName, top, 0))
          ++penalty;

        return penalty;
      };
    }
    
    // Configure a RIGHT transition labelled `label`. The actions, in order, are:
    //   1. Action::attach(Stack 0, Buffer 0);
    //   2. write `label` into the deprel column of the word at the word index;
    //   3. push the word index onto the stack.
    //
    // Cost: std::numeric_limits<int>::max() when either the stack top or the word at the
    // word index is not a token. Otherwise the total of:
    //   - each token from the word index itself up to and including the first token
    //     marked Config::EOSSymbol1 that is linked to the buffer word by a reference
    //     governor relation;
    //   - each stack element below the top (passing config.has) linked to the buffer word;
    //   - 1 if the buffer word's reference governor is not the stack top;
    //   - 1 if `label` differs from the buffer word's reference deprel.
    void Transition::initRight(std::string label)
    {
      sequence.emplace_back(Action::attach(Action::Object::Stack, 0, Action::Object::Buffer, 0));
      sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Buffer, 0, label));
      sequence.emplace_back(Action::pushWordIndexOnStack());

      cost = [label](const Config & config)
      {
        const auto head = config.getStack(0);
        const auto dependent = config.getWordIndex();
        if (!config.isToken(head) || !config.isToken(dependent))
          return std::numeric_limits<int>::max();

        int penalty = 0;

        const auto dependentId = config.getConst(Config::idColName, dependent, 0);
        const auto dependentGoldGov = config.getConst(Config::headColName, dependent, 0);

        // The EOS-marked token is still examined before the scan stops.
        for (int pos = dependent; config.has(0, pos, 0); ++pos)
        {
          if (!config.isToken(pos))
            continue;

          const auto posId = config.getConst(Config::idColName, pos, 0);
          const auto posGoldGov = config.getConst(Config::headColName, pos, 0);
          if (dependentGoldGov == posId || posGoldGov == dependentId)
            ++penalty;

          const bool sentenceEnds = config.getConst(Config::EOSColName, pos, 0) == Config::EOSSymbol1;
          if (sentenceEnds)
            break;
        }

        for (int depth = 1; config.hasStack(depth); ++depth)
        {
          const auto below = config.getStack(depth);
          if (config.has(0, below, 0))
          {
            const auto belowId = config.getConst(Config::idColName, below, 0);
            const auto belowGoldGov = config.getConst(Config::headColName, below, 0);
            if (belowGoldGov == dependentId || dependentGoldGov == belowId)
              ++penalty;
          }
        }

        // TODO: check whether this governor check is necessary.
        if (dependentGoldGov != config.getConst(Config::idColName, head, 0))
          ++penalty;

        if (label != config.getConst(Config::deprelColName, dependent, 0))
          ++penalty;

        return penalty;
      };
    }
    
    // Configure a REDUCE transition: a single action popping the stack.
    //
    // Cost: 0 when the stack top fails config.has or is not a token. Otherwise the total of:
    //   - each token from the word index up to and including the first token marked
    //     Config::EOSSymbol1 that is linked to the stack top by a reference governor
    //     relation;
    //   - 1 if the stack top itself is marked Config::EOSSymbol1.
    void Transition::initReduce()
    {
      sequence.emplace_back(Action::popStack());

      cost = [](const Config & config)
      {
        const auto stackTop = config.getStack(0);
        if (!config.has(0, stackTop, 0))
          return 0;
        if (!config.isToken(stackTop))
          return 0;

        const auto idOfStack = config.getConst(Config::idColName, stackTop, 0);
        const auto govIdOfStack = config.getConst(Config::headColName, stackTop, 0);

        int penalty = 0;
        for (int i = config.getWordIndex(); config.has(0, i, 0); ++i)
        {
          if (!config.isToken(i))
            continue;

          const auto idOfOther = config.getConst(Config::idColName, i, 0);
          const auto govIdOfOther = config.getConst(Config::headColName, i, 0);

          if (govIdOfStack == idOfOther || govIdOfOther == idOfStack)
            ++penalty;

          // The EOS-marked token is examined before the scan stops.
          if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
            break;
        }

        if (config.getConst(Config::EOSColName, stackTop, 0) == Config::EOSSymbol1)
          ++penalty;

        return penalty;
      };
    }
    void Transition::initEOS()
    {
      sequence.emplace_back(Action::setRoot());
    
      sequence.emplace_back(Action::updateIds());
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Action::Object::Stack, 0, Config::EOSSymbol1));
      sequence.emplace_back(Action::emptyStack());
    
      cost = [](const Config & config)
      {
        if (!config.has(0, config.getStack(0), 0))
          return std::numeric_limits<int>::max();
    
        if (!config.isToken(config.getStack(0)))
          return std::numeric_limits<int>::max();
    
        int cost = 0;
    
        if (config.getConst(Config::EOSColName, config.getStack(0), 0) != Config::EOSSymbol1)
    
          cost += 100;
    
    Franck Dary's avatar
    Franck Dary committed
    
    
        auto topStackIndex = config.getStack(0);
        auto topStackGov = config.getConst(Config::headColName, topStackIndex, 0);
        auto topStackGovPred = config.getLastNotEmptyHypConst(Config::headColName, topStackIndex);
    
        --cost;
        for (int i = 0; config.hasStack(i); ++i)
        {
          if (!config.has(0, config.getStack(i), 0))
            continue;
    
          auto otherStackIndex = config.getStack(i);
          auto stackId = config.getConst(Config::idColName, otherStackIndex, 0);
          auto stackGovPred = config.getLastNotEmptyHypConst(Config::headColName, otherStackIndex);
    
          if (util::isEmpty(stackGovPred))
            ++cost;
        }
    
    Franck Dary's avatar
    Franck Dary committed
    
        return cost;
      };
    }