Skip to content
Snippets Groups Projects
Transition.cpp 7.7 KiB
Newer Older
  • Learn to ignore specific revisions
  • #include "Transition.hpp"
    
    #include <regex>
    
    // Builds a transition from its textual name.
    //
    // The name is tried, in order, against the WRITE, SHIFT, REDUCE, LEFT,
    // RIGHT and EOS patterns below; the first one accepted by
    // util::doIfNameMatch selects the init* member that is run. For WRITE the
    // four capture groups are (object, index, column, value) and are forwarded
    // to initWrite as (column, object, index, value).
    //
    // If no pattern is accepted, or if anything throws a std::exception while
    // matching/initialising (e.g. std::stoi inside initWrite), the error is
    // forwarded to util::myThrow together with the offending name.
    Transition::Transition(const std::string & name)
    {
      this->name = name;

      std::regex writeRegex("WRITE ([bs])\\.(.+) (.+) (.+)");
      std::regex shiftRegex("SHIFT");
      std::regex reduceRegex("REDUCE");
      std::regex leftRegex("LEFT (.+)");
      std::regex rightRegex("RIGHT (.+)");
      std::regex eosRegex("EOS");

      try
      {
        if (util::doIfNameMatch(writeRegex, name, [this](auto sm){initWrite(sm[3], sm[1], sm[2], sm[4]);}))
          return;
        if (util::doIfNameMatch(shiftRegex, name, [this](auto){initShift();}))
          return;
        if (util::doIfNameMatch(reduceRegex, name, [this](auto){initReduce();}))
          return;
        if (util::doIfNameMatch(leftRegex, name, [this](auto sm){initLeft(sm[1]);}))
          return;
        if (util::doIfNameMatch(rightRegex, name, [this](auto sm){initRight(sm[1]);}))
          return;
        if (util::doIfNameMatch(eosRegex, name, [this](auto){initEOS();}))
          return;

        // Reaching this point means none of the known patterns was accepted.
        throw std::invalid_argument("no match");
      } catch (const std::exception & e) {util::myThrow(fmt::format("Invalid name '{}' ({})", name, e.what()));}
    }
    
    
    // Runs every action of `sequence`, in order, on `config`.
    // Each action's `apply` callable receives the configuration and the action itself.
    void Transition::apply(Config & config)
    {
      for (Action & step : sequence)
      {
        step.apply(config, step);
      }
    }
    
    // Tells whether every action of `sequence` reports itself appliable to `config`.
    // Stops at the first action whose `appliable` callable answers false.
    bool Transition::appliable(const Config & config) const
    {
      for (const Action & step : sequence)
      {
        if (step.appliable(config, step))
          continue;

        return false;
      }

      return true;
    }
    
    // Evaluates the stored `cost` callable on `config` and returns its result.
    int Transition::getCost(const Config & config) const
    {
      const int result = cost(config);
      return result;
    }
    
    
    const std::string & Transition::getName() const
    {
      return name;
    }
    
    
    void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
    {
      auto objectValue = Action::str2object(object);
      int indexValue = std::stoi(index);
    
      sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));
    
      cost = [colName, objectValue, indexValue, value](const Config & config)
      {
        int lineIndex = 0;
        if (objectValue == Action::Object::Buffer)
          lineIndex = config.getWordIndex() + indexValue;
        else
          lineIndex = config.getStack(indexValue);
    
        if (config.getConst(colName, lineIndex, 0) == value)
          return 0;
    
    void Transition::initShift()
    
      sequence.emplace_back(Action::pushWordIndexOnStack());
    
      cost = [](const Config & config)
      {
    
        if (config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
          return std::numeric_limits<int>::max();
    
    
        if (!config.isToken(config.getWordIndex()))
    
        auto headGovIndex = config.getConst(Config::headColName, config.getWordIndex(), 0);
    
    
        int cost = 0;
        for (int i = 0; config.hasStack(i); ++i)
        {
    
          if (!config.has(0, config.getStack(i), 0))
            continue;
    
    
          auto stackIndex = config.getStack(i);
    
          auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
    
          if (stackGovIndex == std::to_string(config.getWordIndex()) || headGovIndex == std::to_string(stackIndex))
    
            ++cost;
        }
    
        return cost;
      };
    }
    
    // Configures this transition as a LEFT transition carrying `label`.
    //
    // Appends three actions to `sequence`: an Action::attach between
    // (Buffer, 0) and (Stack, 0), a write of `label` into the
    // Config::deprelColName column relative to (Stack, 0), and
    // Action::popStack. Also installs the `cost` callable described below.
    //
    // label : value written to, and compared against, Config::deprelColName.
    void Transition::initLeft(std::string label)
    {
      sequence.emplace_back(Action::attach(Action::Object::Buffer, 0, Action::Object::Stack, 0));
      sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Stack, 0, label));
      sequence.emplace_back(Action::popStack());

      cost = [label](const Config & config)
      {
        const auto stackIndex = config.getStack(0);
        const auto wordIndex = config.getWordIndex();

        // Maximal cost unless both the stack top and the current word are tokens.
        if (!config.isToken(stackIndex) || !config.isToken(wordIndex))
          return std::numeric_limits<int>::max();

        // Loop-invariant string form, computed once rather than per iteration.
        const std::string stackIndexStr = std::to_string(stackIndex);

        auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);

        int total = 0;

        // Scan the lines after the current word, stopping at the first one
        // marked Config::EOSSymbol1, and count those linked to the stack top
        // through the Config::headColName column (in either direction).
        for (int i = wordIndex+1; config.has(0, i, 0); ++i)
        {
          if (!config.isToken(i))
            continue;

          if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
            break;

          auto otherGovIndex = config.getConst(Config::headColName, i, 0);

          if (stackGovIndex == std::to_string(i) || otherGovIndex == stackIndexStr)
            ++total;
        }

        //TODO : Check if this is necessary
        if (stackGovIndex != std::to_string(wordIndex))
          ++total;

        // One more unit when `label` differs from the stored Config::deprelColName value.
        if (label != config.getConst(Config::deprelColName, stackIndex, 0))
          ++total;

        return total;
      };
    }
    
    // Configures this transition as a RIGHT transition carrying `label`.
    //
    // Appends three actions to `sequence`: an Action::attach between
    // (Stack, 0) and (Buffer, 0), a write of `label` into the
    // Config::deprelColName column relative to (Buffer, 0), and
    // Action::pushWordIndexOnStack. Also installs the `cost` callable
    // described below.
    //
    // label : value written to, and compared against, Config::deprelColName.
    void Transition::initRight(std::string label)
    {
      sequence.emplace_back(Action::attach(Action::Object::Stack, 0, Action::Object::Buffer, 0));
      sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Buffer, 0, label));
      sequence.emplace_back(Action::pushWordIndexOnStack());

      cost = [label](const Config & config)
      {
        const auto stackIndex = config.getStack(0);
        const auto wordIndex = config.getWordIndex();

        // Maximal cost unless both the stack top and the current word are tokens.
        if (!config.isToken(stackIndex) || !config.isToken(wordIndex))
          return std::numeric_limits<int>::max();

        // Loop-invariant string form, computed once rather than per iteration.
        const std::string wordIndexStr = std::to_string(wordIndex);

        auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);

        int total = 0;

        // Scan from the current word up to and including the first line
        // marked Config::EOSSymbol1, counting lines linked to the current
        // word through the Config::headColName column (in either direction).
        for (int i = wordIndex; config.has(0, i, 0); ++i)
        {
          if (!config.isToken(i))
            continue;

          auto otherGovIndex = config.getConst(Config::headColName, i, 0);

          if (bufferGovIndex == std::to_string(i) || otherGovIndex == wordIndexStr)
            ++total;

          if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
            break;
        }

        // Same count over the stack, skipping its top element (index 0).
        for (int i = 1; config.hasStack(i); ++i)
        {
          const auto otherStackIndex = config.getStack(i);

          if (!config.has(0, otherStackIndex, 0))
            continue;

          auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0);

          if (otherStackGov == wordIndexStr || bufferGovIndex == std::to_string(otherStackIndex))
            ++total;
        }

        //TODO : Check if this is necessary
        if (bufferGovIndex != std::to_string(stackIndex))
          ++total;

        // One more unit when `label` differs from the stored Config::deprelColName value.
        if (label != config.getConst(Config::deprelColName, wordIndex, 0))
          ++total;

        return total;
      };
    }
    
    // Configures this transition as a REDUCE transition.
    //
    // Appends Action::popStack to `sequence` and installs the `cost`
    // callable described below.
    void Transition::initReduce()
    {
      sequence.emplace_back(Action::popStack());

      cost = [](const Config & config)
      {
        const auto stackIndex = config.getStack(0);

        // Zero cost when the stack top is not present in `config` or is not a token.
        if (!config.has(0, stackIndex, 0))
          return 0;

        if (!config.isToken(stackIndex))
          return 0;

        // Loop-invariant string form, computed once rather than per iteration.
        const std::string stackIndexStr = std::to_string(stackIndex);

        auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);

        int total = 0;

        // Scan from the current word up to and including the first line
        // marked Config::EOSSymbol1, counting lines linked to the stack top
        // through the Config::headColName column (in either direction).
        for (int i = config.getWordIndex(); config.has(0, i, 0); ++i)
        {
          if (!config.isToken(i))
            continue;

          auto otherGovIndex = config.getConst(Config::headColName, i, 0);

          if (stackGovIndex == std::to_string(i) || otherGovIndex == stackIndexStr)
            ++total;

          if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
            break;
        }

        // One more unit when the stack top itself is marked Config::EOSSymbol1.
        if (config.getConst(Config::EOSColName, stackIndex, 0) == Config::EOSSymbol1)
          ++total;

        return total;
      };
    }

    Franck Dary's avatar
    Franck Dary committed
    void Transition::initEOS()
    {
      sequence.emplace_back(Action::setRoot());
    
      sequence.emplace_back(Action::updateIds());
    
    Franck Dary's avatar
    Franck Dary committed
      sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Action::Object::Stack, 0, Config::EOSSymbol1));
      sequence.emplace_back(Action::emptyStack());
    
      cost = [](const Config & config)
      {
        if (!config.has(0, config.getStack(0), 0))
          return std::numeric_limits<int>::max();
    
        if (!config.isToken(config.getStack(0)))
          return std::numeric_limits<int>::max();
    
        if (config.getConst(Config::EOSColName, config.getStack(0), 0) != Config::EOSSymbol1)
    
          return std::numeric_limits<int>::max();
    
    Franck Dary's avatar
    Franck Dary committed
    
    
    
        --cost;
        for (int i = 0; config.hasStack(i); ++i)
        {
          if (!config.has(0, config.getStack(i), 0))
            continue;
    
          auto otherStackIndex = config.getStack(i);
    
          auto otherStackGovPred = config.getLastNotEmptyHypConst(Config::headColName, otherStackIndex);
    
          if (util::isEmpty(otherStackGovPred))
    
    Franck Dary's avatar
    Franck Dary committed
    
        return cost;
      };
    }