#include "Transition.hpp"
#include <regex>

// Builds a transition from its textual name.
//
// The name may start with an optional "<state> " prefix, which is stored in
// `state`; the remainder is stored in `this->name` and tried against the table
// of known patterns below, in order. The first pattern accepted by
// util::doIfNameMatch runs its initializer and construction ends.
//
// Any failure (the name does not match at all, no pattern matches, or an
// initializer throws, e.g. std::stoi on a non-numeric index) is reported
// through util::myThrow with the offending name and the underlying message.
Transition::Transition(const std::string & name)
{
  // Pattern -> initializer table. The handlers receive the match by const
  // reference so that the std::smatch is not copied on every dispatch.
  std::vector<std::pair<std::regex, std::function<void(const std::smatch &)>>> inits
  {
    {std::regex("WRITE ([bs])\\.(.+) (.+) (.+)"),
      [this](const auto & sm){(initWrite(sm[3], sm[1], sm[2], sm[4]));}},
    {std::regex("ADD ([bs])\\.(.+) (.+) (.+)"),
      [this](const auto & sm){(initAdd(sm[3], sm[1], sm[2], sm[4]));}},
    {std::regex("SHIFT"),
      [this](const auto &){initShift();}},
    {std::regex("REDUCE"),
      [this](const auto &){initReduce();}},
    {std::regex("LEFT (.+)"),
      [this](const auto & sm){(initLeft(sm[1]));}},
    {std::regex("RIGHT (.+)"),
      [this](const auto & sm){(initRight(sm[1]));}},
    {std::regex("EOS"),
      [this](const auto &){initEOS();}},
    {std::regex("NOTHING"),
      [this](const auto &){initNothing();}},
    {std::regex("IGNORECHAR"),
      [this](const auto &){initIgnoreChar();}},
    {std::regex("ENDWORD"),
      [this](const auto &){initEndWord();}},
    {std::regex("ADDCHARTOWORD"),
      [this](const auto &){initAddCharToWord();}},
    {std::regex("SPLIT (.+)"),
      [this](const auto & sm){(initSplit(std::stoi(sm.str(1))));}},
    // Group 1 is the first word, group 2 the '@'-prefixed remaining words.
    // The inner repetition is a non-capturing group "(?:...)"; it used to be
    // spelled "(:?...)", i.e. a capturing group preceded by an optional ':'.
    // Both spellings accept the same names and yield the same groups 1 and 2.
    {std::regex("SPLITWORD ([^@]+)((?:@[^@]+)+)"),
      [this](const auto & sm)
      {
        std::vector<std::string> splitRes{sm[1]};
        auto splited = util::split(std::string(sm[2]), '@');
        for (auto & s : splited)
          splitRes.emplace_back(s);
        initSplitWord(splitRes);
      }},
  };

  try
  {
    // Separate the optional "<state> " prefix from the transition name proper.
    if (!util::doIfNameMatch(std::regex("(<(.+)> )?(.+)"), name, [this](const auto & sm)
          {
            this->state = sm[2];
            this->name = sm[3];
          }))
      util::myThrow("doesn't match nameRegex");

    // First matching pattern wins.
    for (auto & it : inits)
      if (util::doIfNameMatch(it.first, this->name, it.second))
        return;

    throw std::invalid_argument("no match");
  }
  catch (const std::exception & e)
  {
    util::myThrow(fmt::format("Invalid name '{}' ({})", this->name, e.what()));
  }
}

// Applies every action of this transition's sequence to `config`, in the
// order in which the actions were registered.
void Transition::apply(Config & config)
{
  for (auto step = sequence.begin(); step != sequence.end(); ++step)
    step->apply(config, *step);
}

bool Transition::appliable(const Config & config) const
{
  if (!state.empty() && state != config.getState())
    return false;

  for (const Action & action : sequence)
    if (!action.appliable(config, action))
      return false;

  return true;
}

// Evaluates the cost function installed by one of the init* methods on
// `config` and returns its result.
int Transition::getCost(const Config & config) const
{
  const int result = cost(config);
  return result;
}

const std::string & Transition::getName() const
{
  return name;
}

// Sets up a WRITE transition: one action that writes `value` into column
// `colName` of the line designated by `object` (converted with
// Action::str2object) and the relative position `index` (parsed with
// std::stoi, which throws on non-numeric input).
//
// The cost is 0 when the value returned by config.getConst for that column and
// line equals `value`, and 1 otherwise.
void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
{
  auto target = Action::str2object(object);
  int offset = std::stoi(index);

  sequence.emplace_back(Action::addHypothesisRelative(colName, target, offset, value));

  cost = [colName, target, offset, value](const Config & config)
  {
    // Resolve the relative position into an absolute line index.
    int lineIndex = 0;
    if (target != Action::Object::Buffer)
      lineIndex = config.getStack(offset);
    else
      lineIndex = config.getWordIndex() + offset;

    const bool matchesReference = (config.getConst(colName, lineIndex, 0) == value);
    return matchesReference ? 0 : 1;
  };
}

// Sets up an ADD transition: one action that adds `value` to column `colName`
// of the line designated by `object` (converted with Action::str2object) and
// the relative position `index` (parsed with std::stoi, which throws on
// non-numeric input).
//
// The cost is 0 when `value` is one of the '|'-separated parts of the value
// returned by config.getConst for that column and line, and 1 otherwise.
void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
{
  auto target = Action::str2object(object);
  int offset = std::stoi(index);

  sequence.emplace_back(Action::addToHypothesisRelative(colName, target, offset, value));

  cost = [colName, target, offset, value](const Config & config)
  {
    // Resolve the relative position into an absolute line index.
    int lineIndex = 0;
    if (target != Action::Object::Buffer)
      lineIndex = config.getStack(offset);
    else
      lineIndex = config.getWordIndex() + offset;

    auto referenceParts = util::split(config.getConst(colName, lineIndex, 0).get(), '|');

    bool found = false;
    for (auto & candidate : referenceParts)
    {
      if (candidate == value)
      {
        found = true;
        break;
      }
    }

    return found ? 0 : 1;
  };
}

// Sets up a NOTHING transition: no action is registered and the cost is
// always 0.
void Transition::initNothing()
{
  cost = [](const Config &) -> int { return 0; };
}

// Sets up an IGNORECHAR transition: a single Action::ignoreCurrentCharacter
// action, with a cost that is always 0.
void Transition::initIgnoreChar()
{
  sequence.emplace_back(Action::ignoreCurrentCharacter());

  cost = [](const Config &) -> int { return 0; };
}

// Sets up an ENDWORD transition: a single Action::endWord action.
//
// The cost is 0 when, for the current word index, the "FORM" value returned by
// config.getConst equals the one returned by config.getAsFeature, and 1
// otherwise.
void Transition::initEndWord()
{
  sequence.emplace_back(Action::endWord());

  cost = [](const Config & config)
  {
    const auto wordIndex = config.getWordIndex();
    const bool formsAgree = (config.getConst("FORM", wordIndex, 0) == config.getAsFeature("FORM", wordIndex));
    return formsAgree ? 0 : 1;
  };
}

// Sets up an ADDCHARTOWORD transition. Registered actions, in order: assert
// that the id column is empty, add lines if needed, append the current
// character to the current word, advance the character index by one.
//
// Cost:
//  - INT_MAX when there is no character at the current character index;
//  - 1 when appending the current letter to the current "FORM" (getAsFeature)
//    would make it longer than the reference "FORM" (getConst), or when the
//    letter differs from the reference at that position;
//  - 0 otherwise.
void Transition::initAddCharToWord()
{
  sequence.emplace_back(Action::assertIsEmpty(Config::idColName));
  sequence.emplace_back(Action::addLinesIfNeeded(0));
  sequence.emplace_back(Action::addCurCharToCurWord());
  sequence.emplace_back(Action::moveCharacterIndex(1));

  cost = [](const Config & config)
  {
    const auto charIndex = config.getCharacterIndex();
    if (!config.hasCharacter(charIndex))
      return std::numeric_limits<int>::max();

    // The letter is formatted to a string, so it may span several elements.
    auto letter = fmt::format("{}", config.getLetter(charIndex));
    const auto wordIndex = config.getWordIndex();
    auto & referenceForm = config.getConst("FORM", wordIndex, 0).get();
    auto & currentForm = config.getAsFeature("FORM", wordIndex).get();

    const auto alreadyBuilt = currentForm.size();
    if (alreadyBuilt + letter.size() > referenceForm.size())
      return 1;

    // The letter must continue the reference form right after what is built.
    for (std::size_t k = 0; k < letter.size(); ++k)
    {
      if (referenceForm[alreadyBuilt + k] != letter[k])
        return 1;
    }

    return 0;
  };
}

// Sets up a SPLITWORD transition from `words`, where words[0] is the text to
// consume and every element of `words` is written as a "FORM" hypothesis on
// consecutive buffer lines.
//
// Registered actions, in order: assert that the id column and "FORM" are
// empty, add words.size() lines if needed, consume the characters of words[0],
// write each word's "FORM" at buffer offset i, then set the multiword ids for
// words.size()-1 elements.
//
// Cost:
//  - INT_MAX when the current word is not a multiword, or when its multiword
//    size plus 2 differs from words.size();
//  - otherwise the number of positions i for which the "FORM" column has no
//    value at wordIndex+i or its getConst value differs from words[i].
void Transition::initSplitWord(std::vector<std::string> words)
{
  auto consumedChars = util::splitAsUtf8(words[0]);

  sequence.emplace_back(Action::assertIsEmpty(Config::idColName));
  sequence.emplace_back(Action::assertIsEmpty("FORM"));
  sequence.emplace_back(Action::addLinesIfNeeded(words.size()));
  sequence.emplace_back(Action::consumeCharacterIndex(consumedChars));
  for (unsigned int offset = 0; offset < words.size(); ++offset)
    sequence.emplace_back(Action::addHypothesisRelative("FORM", Action::Object::Buffer, offset, words[offset]));
  sequence.emplace_back(Action::setMultiwordIds(words.size()-1));

  cost = [words](const Config & config)
  {
    const auto wordIndex = config.getWordIndex();

    if (!config.isMultiword(wordIndex))
      return std::numeric_limits<int>::max();

    const int expectedCount = config.getMultiwordSize(wordIndex)+2;
    if (expectedCount != static_cast<int>(words.size()))
      return std::numeric_limits<int>::max();

    int mismatches = 0;
    for (unsigned int offset = 0; offset < words.size(); ++offset)
    {
      if (!config.has("FORM", wordIndex+offset, 0) || config.getConst("FORM", wordIndex+offset, 0) != words[offset])
        ++mismatches;
    }

    return mismatches;
  };
}

// Sets up a SPLIT transition: a single Action::split(index) action.
//
// The cost is delegated to the transition found at position `index` in
// config.getAppliableSplitTransitions(); it is INT_MAX when `index` is outside
// that container.
void Transition::initSplit(int index)
{
  sequence.emplace_back(Action::split(index));

  cost = [index](const Config & config)
  {
    auto & candidates = config.getAppliableSplitTransitions();

    const bool inRange = index >= 0 && index < static_cast<int>(candidates.size());
    if (!inRange)
      return std::numeric_limits<int>::max();

    return candidates[index]->getCost(config);
  };
}

// Sets up a SHIFT transition: a single action pushing the current word index
// on the stack.
//
// Cost:
//  - INT_MAX when the stack top exists and its EOS column (getConst) equals
//    Config::EOSSymbol1;
//  - 0 when the current word is not a token;
//  - otherwise the number of stack elements that, according to the head column
//    read with getConst, are governed by the current word or govern it.
//    (presumably these are the reference arcs made unreachable by the shift
//    -- TODO confirm)
void Transition::initShift()
{
  sequence.emplace_back(Action::pushWordIndexOnStack());

  cost = [](const Config & config)
  {
    if (config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
      return std::numeric_limits<int>::max();

    const auto wordIndex = config.getWordIndex();
    if (!config.isToken(wordIndex))
      return 0;

    auto bufferHead = config.getConst(Config::headColName, wordIndex, 0);
    const std::string wordIndexStr = std::to_string(wordIndex);

    int lostArcs = 0;
    for (int depth = 0; config.hasStack(depth); ++depth)
    {
      auto stackIndex = config.getStack(depth);

      // Stack entries without a line in column 0 are skipped.
      if (!config.has(0, stackIndex, 0))
        continue;

      auto stackHead = config.getConst(Config::headColName, stackIndex, 0);

      if (stackHead == wordIndexStr || bufferHead == std::to_string(stackIndex))
        ++lostArcs;
    }

    return lostArcs;
  };
}

// Sets up a LEFT transition labelled `label`. Registered actions, in order:
// attach (Buffer 0, Stack 0), write `label` in the deprel column of the stack
// top, pop the stack.
//
// Cost (head / deprel / EOS values are read with config.getConst):
//  - INT_MAX unless both the stack top and the current word are tokens;
//  - +1 for every token i after the current word, up to but excluding the
//    first one whose EOS column equals Config::EOSSymbol1, that is the head of
//    the stack top or whose head is the stack top;
//  - +1 when the head of the stack top is not the current word;
//  - +1 when `label` differs from the deprel of the stack top.
void Transition::initLeft(std::string label)
{
  sequence.emplace_back(Action::attach(Action::Object::Buffer, 0, Action::Object::Stack, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Stack, 0, label));
  sequence.emplace_back(Action::popStack());

  cost = [label](const Config & config)
  {
    auto stackIndex = config.getStack(0);
    auto wordIndex = config.getWordIndex();

    const bool bothTokens = config.isToken(stackIndex) && config.isToken(wordIndex);
    if (!bothTokens)
      return std::numeric_limits<int>::max();

    auto stackHead = config.getConst(Config::headColName, stackIndex, 0);
    const std::string stackIndexStr = std::to_string(stackIndex);

    int penalty = 0;

    // Scan the rest of the buffer, stopping at the end-of-sentence marker.
    for (int line = wordIndex+1; config.has(0, line, 0); ++line)
    {
      if (!config.isToken(line))
        continue;

      if (config.getConst(Config::EOSColName, line, 0) == Config::EOSSymbol1)
        break;

      auto lineHead = config.getConst(Config::headColName, line, 0);

      if (stackHead == std::to_string(line) || lineHead == stackIndexStr)
        ++penalty;
    }

    //TODO : Check if this is necessary
    if (stackHead != std::to_string(wordIndex))
      ++penalty;

    if (label != config.getConst(Config::deprelColName, stackIndex, 0))
      ++penalty;

    return penalty;
  };
}

// Sets up a RIGHT transition labelled `label`. Registered actions, in order:
// attach (Stack 0, Buffer 0), write `label` in the deprel column of the
// current word, push the current word index on the stack.
//
// Cost (head / deprel / EOS values are read with config.getConst):
//  - INT_MAX unless both the stack top and the current word are tokens;
//  - +1 for every token i from the current word up to and including the first
//    one whose EOS column equals Config::EOSSymbol1, that is the head of the
//    current word;
//  - +1 for every stack element below the top (present in column 0) whose
//    head is the current word or which is the head of the current word;
//  - +1 when the head of the current word is not the stack top;
//  - +1 when `label` differs from the deprel of the current word.
void Transition::initRight(std::string label)
{
  sequence.emplace_back(Action::attach(Action::Object::Stack, 0, Action::Object::Buffer, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Buffer, 0, label));
  sequence.emplace_back(Action::pushWordIndexOnStack());

  cost = [label](const Config & config)
  {
    auto stackIndex = config.getStack(0);
    auto wordIndex = config.getWordIndex();
    if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
      return std::numeric_limits<int>::max();

    int cost = 0;

    auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);
    // Loop-invariant textual form of the current word index.
    const std::string wordIndexStr = std::to_string(wordIndex);

    // Buffer part: only the head of the current word matters here, so the
    // head of line i is not looked up (it used to be fetched and discarded).
    for (int i = wordIndex; config.has(0, i, 0); ++i)
    {
      if (!config.isToken(i))
        continue;

      if (bufferGovIndex == std::to_string(i))
        ++cost;

      // The end-of-sentence token itself is examined before stopping.
      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
        break;
    }

    // Stack part: every element below the top.
    for (int i = 1; config.hasStack(i); ++i)
    {
      if (!config.has(0, config.getStack(i), 0))
        continue;

      auto otherStackIndex = config.getStack(i);
      auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0);

      if (otherStackGov == wordIndexStr || bufferGovIndex == std::to_string(otherStackIndex))
        ++cost;
    }

    //TODO : Check if this is necessary
    if (bufferGovIndex != std::to_string(stackIndex))
      ++cost;

    if (label != config.getConst(Config::deprelColName, wordIndex, 0))
      ++cost;

    return cost;
  };
}

// Sets up a REDUCE transition: a single action popping the stack.
//
// Cost (head / EOS values are read with config.getConst):
//  - 0 when the stack top is not a token;
//  - +1 for every token i from the current word up to and including the first
//    one whose EOS column equals Config::EOSSymbol1, that is the head of the
//    stack top or whose head is the stack top;
//  - +1 when the EOS column of the stack top equals Config::EOSSymbol1.
void Transition::initReduce()
{
  sequence.emplace_back(Action::popStack());

  cost = [](const Config & config)
  {
    auto stackIndex = config.getStack(0);
    if (!config.isToken(stackIndex))
      return 0;

    auto stackHead = config.getConst(Config::headColName, stackIndex, 0);
    const std::string stackIndexStr = std::to_string(stackIndex);

    int penalty = 0;

    for (int line = config.getWordIndex(); config.has(0, line, 0); ++line)
    {
      if (!config.isToken(line))
        continue;

      auto lineHead = config.getConst(Config::headColName, line, 0);

      if (stackHead == std::to_string(line) || lineHead == stackIndexStr)
        ++penalty;

      // The end-of-sentence token itself is examined before stopping.
      if (config.getConst(Config::EOSColName, line, 0) == Config::EOSSymbol1)
        break;
    }

    if (config.getConst(Config::EOSColName, stackIndex, 0) == Config::EOSSymbol1)
      ++penalty;

    return penalty;
  };
}

// Sets up an EOS transition. Registered actions, in order: set the root,
// update the ids, write Config::EOSSymbol1 in the EOS column of the stack top,
// empty the stack.
//
// Cost:
//  - INT_MAX when the stack top has no line in column 0, is not a token, or
//    its EOS column (getConst) differs from Config::EOSSymbol1;
//  - otherwise the number of stack elements (present in column 0) whose head
//    value from config.getAsFeature is empty, minus one.
//    (the -1 presumably discounts the element that becomes the root
//    -- TODO confirm)
void Transition::initEOS()
{
  sequence.emplace_back(Action::setRoot());
  sequence.emplace_back(Action::updateIds());
  sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Action::Object::Stack, 0, Config::EOSSymbol1));
  sequence.emplace_back(Action::emptyStack());

  cost = [](const Config & config)
  {
    const auto top = config.getStack(0);

    if (!config.has(0, top, 0))
      return std::numeric_limits<int>::max();

    if (!config.isToken(top))
      return std::numeric_limits<int>::max();

    if (config.getConst(Config::EOSColName, top, 0) != Config::EOSSymbol1)
      return std::numeric_limits<int>::max();

    // Start one below zero, then count the stack elements lacking a head.
    int headless = -1;
    for (int depth = 0; config.hasStack(depth); ++depth)
    {
      auto stackIndex = config.getStack(depth);
      if (!config.has(0, stackIndex, 0))
        continue;

      auto predictedHead = config.getAsFeature(Config::headColName, stackIndex);

      if (util::isEmpty(predictedHead))
        ++headless;
    }

    return headless;
  };
}