#include "Transition.hpp"
#include <regex>

// Build a transition from its textual name.
//
// The name is tried against the following patterns, in this order, and the
// first one accepted by util::doIfNameMatch selects the init* routine that
// fills the action sequence and the cost function:
//   "WRITE <b|s>.<index> <column> <value>"  -> initWrite
//   "SHIFT"                                 -> initShift
//   "REDUCE"                                -> initReduce
//   "LEFT <label>"                          -> initLeft
//   "RIGHT <label>"                         -> initRight
//   "EOS"                                   -> initEOS
//
// name : textual description of the transition.
//
// Any std::exception raised during the process (including the internal
// "no match" error when no pattern is accepted) is reported through
// util::myThrow together with the offending name.
Transition::Transition(const std::string & name)
{
  this->name = name;

  std::regex writeRegex("WRITE ([bs])\\.(.+) (.+) (.+)");
  std::regex shiftRegex("SHIFT");
  std::regex reduceRegex("REDUCE");
  std::regex leftRegex("LEFT (.+)");
  std::regex rightRegex("RIGHT (.+)");
  std::regex eosRegex("EOS");

  try
  {
    // Match results are received by const reference so they are not copied.
    // For WRITE, the capture groups are (object, index, column, value) while
    // initWrite expects (column, object, index, value), hence the reordering.
    if (util::doIfNameMatch(writeRegex, name, [this](const auto & sm){initWrite(sm[3], sm[1], sm[2], sm[4]);}))
      return;
    if (util::doIfNameMatch(shiftRegex, name, [this](const auto &){initShift();}))
      return;
    if (util::doIfNameMatch(reduceRegex, name, [this](const auto &){initReduce();}))
      return;
    if (util::doIfNameMatch(leftRegex, name, [this](const auto & sm){initLeft(sm[1]);}))
      return;
    if (util::doIfNameMatch(rightRegex, name, [this](const auto & sm){initRight(sm[1]);}))
      return;
    if (util::doIfNameMatch(eosRegex, name, [this](const auto &){initEOS();}))
      return;

    throw std::invalid_argument("no match");
  }
  catch (const std::exception & e)
  {
    util::myThrow(fmt::format("Invalid name '{}' ({})", name, e.what()));
  }
}

// Run every action of this transition on `config`, in the order in which the
// actions were added to `sequence`. Each action receives itself as second
// argument of its own apply callable.
void Transition::apply(Config & config)
{
  for (auto & step : sequence)
  {
    step.apply(config, step);
  }
}

bool Transition::appliable(const Config & config) const
{
  for (const Action & action : sequence)
    if (!action.appliable(config, action))
      return false;

  return true;
}

// Evaluate the cost function installed by one of the init* routines on
// `config` and return its result.
int Transition::getCost(const Config & config) const
{
  const int transitionCost = cost(config);
  return transitionCost;
}

const std::string & Transition::getName() const
{
  return name;
}

// Set up a WRITE transition.
//
// colName : column in which `value` is written.
// object  : textual object designator, converted with Action::str2object.
// index   : textual relative index, converted with std::stoi (which throws on
//           a non-numeric string).
// value   : value to write.
//
// The action sequence is a single Action::addHypothesisRelative. The cost is 0
// when the value read through config.getConst for the targeted line equals
// `value`, and 1 otherwise.
void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
{
  auto target = Action::str2object(object);
  int offset = std::stoi(index);

  sequence.emplace_back(Action::addHypothesisRelative(colName, target, offset, value));

  cost = [colName, target, offset, value](const Config & config) -> int
  {
    // Resolve the targeted line: relative to the current word for the buffer,
    // otherwise taken from the stack.
    // NOTE(review): no hasStack/has check is done here; presumably the cost is
    // only evaluated when the transition is appliable — confirm.
    const bool relativeToBuffer = (target == Action::Object::Buffer);
    const int targetLine = relativeToBuffer
      ? static_cast<int>(config.getWordIndex() + offset)
      : static_cast<int>(config.getStack(offset));

    return (config.getConst(colName, targetLine, 0) == value) ? 0 : 1;
  };
}

// Set up a SHIFT transition.
//
// The action sequence is a single Action::pushWordIndexOnStack.
//
// Cost function:
//  - std::numeric_limits<int>::max() when the stack is not empty and the value
//    read through getConst in Config::EOSColName for the top of the stack is
//    Config::EOSSymbol1;
//  - 0 when the current word is not a token;
//  - otherwise the number of stack elements whose Config::headColName value is
//    the current word index, or whose index is the Config::headColName value
//    of the current word.
void Transition::initShift()
{
  sequence.emplace_back(Action::pushWordIndexOnStack());

  cost = [](const Config & config)
  {
    if (config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
      return std::numeric_limits<int>::max();

    auto wordIndex = config.getWordIndex();

    if (!config.isToken(wordIndex))
      return 0;

    auto headGovIndex = config.getConst(Config::headColName, wordIndex, 0);

    // The current word index does not change while scanning the stack, so its
    // textual form is built only once.
    const std::string wordIndexStr = std::to_string(wordIndex);

    int cost = 0;
    for (int i = 0; config.hasStack(i); ++i)
    {
      auto stackIndex = config.getStack(i);

      // Skip stack entries for which config.has(0, index, 0) is false.
      if (!config.has(0, stackIndex, 0))
        continue;

      auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);

      if (stackGovIndex == wordIndexStr || headGovIndex == std::to_string(stackIndex))
        ++cost;
    }

    return cost;
  };
}

// Set up a LEFT transition carrying the dependency label `label`.
//
// Action sequence: Action::attach(Buffer 0, Stack 0), then
// Action::addHypothesisRelative writing `label` in Config::deprelColName for
// Stack 0, then Action::popStack.
//
// Cost function:
//  - std::numeric_limits<int>::max() unless both the top of the stack and the
//    current word are tokens;
//  - otherwise the sum of:
//      * the number of tokens after the current word (stopping before the first
//        one whose Config::EOSColName value is Config::EOSSymbol1) that are
//        linked with the top of the stack through Config::headColName, in
//        either direction;
//      * 1 if the Config::headColName value of the top of the stack is not the
//        current word index;
//      * 1 if `label` differs from the Config::deprelColName value of the top
//        of the stack.
//
// NOTE(review): config.getStack(0) is called without a hasStack(0) check;
// presumably the cost is only evaluated when the stack is not empty — confirm.
void Transition::initLeft(std::string label)
{
  sequence.emplace_back(Action::attach(Action::Object::Buffer, 0, Action::Object::Stack, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Stack, 0, label));
  sequence.emplace_back(Action::popStack());

  cost = [label](const Config & config)
  {
    auto stackIndex = config.getStack(0);
    auto wordIndex = config.getWordIndex();
    if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
      return std::numeric_limits<int>::max();

    int cost = 0;

    auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);

    // Loop-invariant: textual form of the stack top index, built once.
    const std::string stackIndexStr = std::to_string(stackIndex);

    for (int i = wordIndex+1; config.has(0, i, 0); ++i)
    {
      if (!config.isToken(i))
        continue;

      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
        break;

      auto otherGovIndex = config.getConst(Config::headColName, i, 0);

      if (stackGovIndex == std::to_string(i) || otherGovIndex == stackIndexStr)
        ++cost;
    }

    //TODO : Check if this is necessary
    if (stackGovIndex != std::to_string(wordIndex))
      ++cost;

    if (label != config.getConst(Config::deprelColName, stackIndex, 0))
      ++cost;

    return cost;
  };
}

// Set up a RIGHT transition carrying the dependency label `label`.
//
// Action sequence: Action::attach(Stack 0, Buffer 0), then
// Action::addHypothesisRelative writing `label` in Config::deprelColName for
// Buffer 0, then Action::pushWordIndexOnStack.
//
// Cost function:
//  - std::numeric_limits<int>::max() unless both the top of the stack and the
//    current word are tokens;
//  - otherwise the sum of:
//      * the number of tokens from the current word onwards (up to and
//        including the first one whose Config::EOSColName value is
//        Config::EOSSymbol1) that are linked with the current word through
//        Config::headColName, in either direction;
//      * the number of stack elements below the top that are linked with the
//        current word through Config::headColName, in either direction;
//      * 1 if the Config::headColName value of the current word is not the
//        index of the top of the stack;
//      * 1 if `label` differs from the Config::deprelColName value of the
//        current word.
//
// NOTE(review): config.getStack(0) is called without a hasStack(0) check;
// presumably the cost is only evaluated when the stack is not empty — confirm.
void Transition::initRight(std::string label)
{
  sequence.emplace_back(Action::attach(Action::Object::Stack, 0, Action::Object::Buffer, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Buffer, 0, label));
  sequence.emplace_back(Action::pushWordIndexOnStack());

  cost = [label](const Config & config)
  {
    auto stackIndex = config.getStack(0);
    auto wordIndex = config.getWordIndex();
    if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
      return std::numeric_limits<int>::max();

    int cost = 0;

    auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);

    // Loop-invariant: textual form of the current word index, shared by both
    // loops below instead of being rebuilt at every iteration.
    const std::string wordIndexStr = std::to_string(wordIndex);

    for (int i = wordIndex; config.has(0, i, 0); ++i)
    {
      if (!config.isToken(i))
        continue;

      auto otherGovIndex = config.getConst(Config::headColName, i, 0);

      if (bufferGovIndex == std::to_string(i) || otherGovIndex == wordIndexStr)
        ++cost;

      // The end-of-sentence token itself is counted before stopping.
      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
        break;
    }

    // Starts at 1: the top of the stack is handled separately below.
    for (int i = 1; config.hasStack(i); ++i)
    {
      auto otherStackIndex = config.getStack(i);

      if (!config.has(0, otherStackIndex, 0))
        continue;

      auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0);

      if (otherStackGov == wordIndexStr || bufferGovIndex == std::to_string(otherStackIndex))
        ++cost;
    }

    //TODO : Check if this is necessary
    if (bufferGovIndex != std::to_string(stackIndex))
      ++cost;

    if (label != config.getConst(Config::deprelColName, wordIndex, 0))
      ++cost;

    return cost;
  };
}

// Set up a REDUCE transition.
//
// The action sequence is a single Action::popStack.
//
// Cost function:
//  - 0 when config.has(0, top, 0) is false for the top of the stack, or when
//    the top of the stack is not a token;
//  - otherwise the sum of:
//      * the number of tokens from the current word onwards (up to and
//        including the first one whose Config::EOSColName value is
//        Config::EOSSymbol1) that are linked with the top of the stack through
//        Config::headColName, in either direction;
//      * 1 if the Config::EOSColName value of the top of the stack is
//        Config::EOSSymbol1.
//
// NOTE(review): config.getStack(0) is called without a hasStack(0) check;
// presumably the cost is only evaluated when the stack is not empty — confirm.
void Transition::initReduce()
{
  sequence.emplace_back(Action::popStack());

  cost = [](const Config & config)
  {
    // The top of the stack is fetched once and reused everywhere below.
    auto stackIndex = config.getStack(0);

    if (!config.has(0, stackIndex, 0))
      return 0;

    if (!config.isToken(stackIndex))
      return 0;

    int cost = 0;

    auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);

    // Loop-invariant: textual form of the stack top index, built once.
    const std::string stackIndexStr = std::to_string(stackIndex);

    for (int i = config.getWordIndex(); config.has(0, i, 0); ++i)
    {
      if (!config.isToken(i))
        continue;

      auto otherGovIndex = config.getConst(Config::headColName, i, 0);

      if (stackGovIndex == std::to_string(i) || otherGovIndex == stackIndexStr)
        ++cost;

      // The end-of-sentence token itself is counted before stopping.
      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
        break;
    }

    if (config.getConst(Config::EOSColName, stackIndex, 0) == Config::EOSSymbol1)
      ++cost;

    return cost;
  };
}

// Set up an EOS transition.
//
// Action sequence: Action::setRoot, Action::updateIds, then
// Action::addHypothesisRelative writing Config::EOSSymbol1 in
// Config::EOSColName for Stack 0, then Action::emptyStack.
//
// Cost function:
//  - std::numeric_limits<int>::max() when config.has(0, top, 0) is false for
//    the top of the stack, when the top of the stack is not a token, or when
//    its Config::EOSColName value read through getConst is not
//    Config::EOSSymbol1;
//  - otherwise (number of stack elements whose last non-empty hypothesis in
//    Config::headColName is empty) minus one.
//
// NOTE(review): config.getStack(0) is called without a hasStack(0) check;
// presumably the cost is only evaluated when the stack is not empty — confirm.
void Transition::initEOS()
{
  sequence.emplace_back(Action::setRoot());
  sequence.emplace_back(Action::updateIds());
  sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Action::Object::Stack, 0, Config::EOSSymbol1));
  sequence.emplace_back(Action::emptyStack());

  cost = [](const Config & config) -> int
  {
    const auto top = config.getStack(0);

    // The three rejection tests are evaluated in order and short-circuit.
    const bool rejected = !config.has(0, top, 0)
                       || !config.isToken(top)
                       || config.getConst(Config::EOSColName, top, 0) != Config::EOSSymbol1;
    if (rejected)
    {
      return std::numeric_limits<int>::max();
    }

    // Starts at -1, so one stack element without a head hypothesis is free.
    int total = -1;

    for (int depth = 0; config.hasStack(depth); ++depth)
    {
      const auto element = config.getStack(depth);

      if (config.has(0, element, 0))
      {
        auto predictedGov = config.getLastNotEmptyHypConst(Config::headColName, element);

        if (util::isEmpty(predictedGov))
        {
          ++total;
        }
      }
    }

    return total;
  };
}