/*Copyright (c) 2019 Alexis Nasr && Franck Dary

 Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

 The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.*/
#include "Decoder.hpp"
#include "util.hpp"
#include "Error.hpp"
#include "ActionBank.hpp"
#include "ProgramOutput.hpp"
#include <algorithm>
#include <chrono>
#include <memory>
#include <utility>

/// Builds a decoder over a transition machine and a configuration.
///
/// @param transitionMachine  Used to initialize the `tm` member.
/// @param configuration      Used to initialize the `config` member.
Decoder::Decoder(TransitionMachine & transitionMachine, Config & configuration)
  : tm(transitionMachine),
    config(configuration)
{
}

/// Thrown by getClassifierAction when decoding must stop.
/// Caught by the decode loops to leave them.
struct EndOfDecode : public std::exception
{
  /// `noexcept` replaces the dynamic exception specification `throw()`,
  /// which was deprecated in C++11 and removed in C++20.
  const char * what() const noexcept override
  {
    return "End of Decode";
  }
};

/// Thrown by getClassifierAction when fewer valid actions exist than the
/// requested index.
struct NoMoreActions : public std::exception
{
  /// `noexcept` replaces the dynamic exception specification `throw()`,
  /// which was deprecated in C++11 and removed in C++20.
  const char * what() const noexcept override
  {
    return "No More Actions";
  }
};

/// Returns true when `config` has the tape named by
/// ProgramParameters::sequenceDelimiterTape. The name suggests this means
/// sequence ends are predicted rather than given.
bool EOSisPredicted(Config & config)
{
  const bool hasDelimiterTape = config.hasTape(ProgramParameters::sequenceDelimiterTape);
  return hasDelimiterTape;
}

/// Records the outcome of one prediction for error analysis.
///
/// Does nothing unless all of the following hold:
/// - the classifier needs training;
/// - ProgramParameters::errorAnalysis is set;
/// - ProgramParameters::classifierName is empty or equals the classifier's name.
///
/// The "gold" action is the prediction itself when it is one of the
/// classifier's zero-cost actions, otherwise the first zero-cost action.
/// Both actions, the prediction's cost and both link lengths are added to
/// `errors` together with `weightedActions`.
///
/// Aborts the program, after dumping the configuration and the cost of every
/// weighted action, when getZeroCostActions returns nothing.
void checkAndRecordError(Config & config, Classifier * classifier, Classifier::WeightedActions & weightedActions, std::string & action, Errors & errors)
{
  if (!classifier->needsTrain() || !ProgramParameters::errorAnalysis)
    return;
  if (classifier->name != ProgramParameters::classifierName && !ProgramParameters::classifierName.empty())
    return;

  auto zeroCostActions = classifier->getZeroCostActions(config);
  if (zeroCostActions.empty())
  {
    fprintf(stderr, "ERROR (%s) : could not find zero cost action for classifier \'%s\'. Aborting.\n", ERRINFO, classifier->name.c_str());
    config.printForDebug(stderr);
    for (auto & weighted : weightedActions)
    {
      fprintf(stderr, "%s : ", weighted.second.second.c_str());
      Oracle::explainCostOfAction(stderr, config, weighted.second.second);
    }
    exit(1);
  }

  std::string goldAction = zeroCostActions[0];
  for (auto & candidate : zeroCostActions)
  {
    if (action == candidate)
    {
      goldAction = candidate;
      break;
    }
  }

  int predictionCost = classifier->getActionCost(config, action);
  int predictedLinkLength = ActionBank::getLinkLength(config, action);
  int goldLinkLength = ActionBank::getLinkLength(config, goldAction);
  errors.add({action, goldAction, weightedActions, predictionCost, predictedLinkLength, goldLinkLength});
}

/// In interactive mode, prints a progress line (percentage and actions/s) to
/// stderr, terminated by a carriage return.
///
/// Progress is measured as follows:
/// - normally, the tape head against ProgramParameters::tapeSize;
/// - in raw-input mode, the raw input head index against the raw input size.
///
/// A line is printed only when the step count is nonzero and either it is a
/// multiple of `nbActionsCutoff` or fewer than `nbActionsCutoff` steps remain.
void printAdvancement(Config & config, float currentSpeed, int nbActionsCutoff)
{
  if (!ProgramParameters::interactive)
    return;

  int steps = config.getHead();
  int totalSize = ProgramParameters::tapeSize;
  if (ProgramParameters::rawInput)
  {
    steps = config.rawInputHeadIndex;
    totalSize = config.rawInput.size();
  }

  if (steps == 0)
    return;

  bool onCheckpoint = steps % nbActionsCutoff == 0;
  bool nearEnd = totalSize - steps < nbActionsCutoff;
  if (onCheckpoint || nearEnd)
    fprintf(stderr, "Decode : %.2f%%  speed : %s actions/s\r", 100.0*steps/totalSize, util::int2humanStr((int)currentSpeed).c_str());
}

/// In debug mode, writes to `output`:
/// - the configuration;
/// - the transition machine's current state;
/// - the weighted actions, followed by an empty line.
void printDebugInfos(FILE * output, Config & config, TransitionMachine & tm, Classifier::WeightedActions & weightedActions)
{
  if (!ProgramParameters::debug)
    return;

  config.printForDebug(output);
  fprintf(output, "State : \'%s\'\n", tm.getCurrentState().c_str());
  Classifier::printWeightedActions(output, weightedActions);
  fprintf(output, "\n");
}

/// Picks the `index`-th valid entry of `weightedActions` and returns its
/// (probability, action name).
///
/// Indexing is 0-based and counts, in vector order, only entries whose
/// `first` flag is set. The selected action must be appliable to `config`.
/// `weightedActions` is only read.
///
/// @throws EndOfDecode
///   - when the first entry has an empty action name;
///   - when the selected action is not appliable, or no such entry exists,
///     while config.endOfTapes() holds. The stack is emptied first.
/// @throws NoMoreActions  when fewer than index+1 valid entries exist and the
///                        tapes are not exhausted.
/// Aborts the program if `weightedActions` is empty, or if the selected action
/// is not appliable before the end of the tapes.
std::pair<float,std::string> getClassifierAction(Config & config, Classifier::WeightedActions & weightedActions, Classifier * classifier, unsigned int index)
{
  if (weightedActions.empty())
  {
    fprintf(stderr, "ERROR (%s) : weightedActions is empty. Aborting.\n", ERRINFO);
    exit(1);
  }

  if (weightedActions[0].second.second.empty())
    throw EndOfDecode();

  // Local copies: a reference into weightedActions[0] would make the
  // assignments below overwrite the first entry's action name.
  std::string predictedAction;
  float proba = 0.0;
  Action * action = nullptr;
  bool found = false;
  unsigned int nbValidActions = 0;

  for (auto & weightedAction : weightedActions)
  {
    predictedAction = weightedAction.second.second;
    proba = weightedAction.second.first;
    action = classifier->getAction(predictedAction);

    if (weightedAction.first)
    {
      if (nbValidActions == index)
      {
        found = true;
        break;
      }
      nbValidActions++;
    }
  }

  // When nothing was found, `action` is the last entry's action. It is still
  // tested first, exactly as before.
  if (!action->appliable(config) || !found)
  {
    // The analysis is finished but the stack is not empty.
    if (config.endOfTapes())
    {
      while (!config.stackEmpty())
        config.stackPop();
      throw EndOfDecode();
    }

    if (!found)
      throw NoMoreActions();

    fprintf(stderr, "ERROR (%s) : action \'%s\' is not appliable. Aborting\n", ERRINFO, predictedAction.c_str());
    exit(1);
  }

  return {proba, predictedAction};
}

/// Updates `currentSpeed` (actions per second) once at least `nbActionsCutoff`
/// actions have been counted since `pastTime`, then restarts the measure.
/// Below the cutoff nothing is modified.
///
/// `pastTime` is typed on high_resolution_clock because callers initialize it
/// with high_resolution_clock::now(). That clock is an alias of system_clock
/// only on some standard libraries, so the former system_clock parameter did
/// not compile everywhere. When no time elapsed, the previous speed is kept
/// instead of storing an infinite value.
void computeSpeed(std::chrono::time_point<std::chrono::high_resolution_clock> & pastTime, int & nbActions, int & nbActionsCutoff, float & currentSpeed)
{
  if (nbActions < nbActionsCutoff)
    return;

  auto actualTime = std::chrono::high_resolution_clock::now();
  double seconds = std::chrono::duration<double>(actualTime - pastTime).count();
  if (seconds > 0.0)
    currentSpeed = nbActions / seconds;

  pastTime = actualTime;
  nbActions = 0;
}

/// Detects the end of a sequence and reports its average entropy.
///
/// A sequence ends when the cell just before the head on the delimiter tape
/// holds ProgramParameters::sequenceDelimiter. `justFlipped` remembers that
/// this crossing was already handled, so the work runs once per delimiter,
/// not on every action taken while the head has not moved.
///
/// When it fires and entropy printing or error analysis is enabled:
/// - a new sequence is started in `errors`;
/// - `entropyAccumulator` is averaged over `nbActionsInSequence`;
/// - the average is printed if requested;
/// - both accumulators are reset.
///
/// Nothing happens when EOSisPredicted(config) is false.
void computeAndPrintSequenceEntropy(Config & config, bool & justFlipped, Errors & errors, float & entropyAccumulator, int & nbActionsInSequence)
{
  if (!EOSisPredicted(config))
    return;

  // No cell precedes the head yet.
  if (config.getHead() < 1)
    return;

  bool afterDelimiter = config.getTape(ProgramParameters::sequenceDelimiterTape)[-1] == ProgramParameters::sequenceDelimiter;
  if (!afterDelimiter)
  {
    justFlipped = false;
    return;
  }

  // This delimiter crossing has already been handled.
  if (justFlipped)
    return;
  justFlipped = true;

  if (!(ProgramParameters::printEntropy || ProgramParameters::errorAnalysis))
    return;

  errors.newSequence();
  if (nbActionsInSequence > 0)
    entropyAccumulator /= nbActionsInSequence;
  nbActionsInSequence = 0;
  if (ProgramParameters::printEntropy)
    fprintf(stderr, "Entropy : %.2f\n", entropyAccumulator);
  entropyAccumulator = 0.0;
}

/// Computes the entropy of `weightedActions`, adds it to `entropyAccumulator`
/// and appends it to the configuration's entropy history.
void computeAndRecordEntropy(Config & config, Classifier::WeightedActions & weightedActions, float & entropyAccumulator)
{
  const float actionsEntropy = Classifier::computeEntropy(weightedActions);
  entropyAccumulator += actionsEntropy;
  config.addToEntropyHistory(actionsEntropy);
}

/// Applies the current classifier's action named `actionName` to `config`,
/// then follows the corresponding transition of `tm`.
///
/// Before applying, the action is also recorded in the configuration's
/// action history under the classifier's name.
void applyActionAndTakeTransition(TransitionMachine & tm, const std::string & actionName, Config & config)
{
  Classifier * currentClassifier = tm.getCurrentClassifier();

  Action * action = currentClassifier->getAction(actionName);
  TransitionMachine::Transition * transition = tm.getTransition(actionName);

  action->setInfos(currentClassifier->name);
  config.addToActionsHistory(currentClassifier->name, actionName, 0);

  if (ProgramParameters::debug)
    fprintf(stderr, "Applying action=<%s>\n", action->name.c_str());

  action->apply(config);
  tm.takeTransition(transition);
}

/// Decodes the whole input, then prints ProgramOutput to stdout.
///
/// The configuration is reset first unless ProgramParameters::rawInput is set,
/// and its tapes are then filled with the input. Beam search is used when
/// ProgramParameters::beamSize exceeds 1; otherwise decoding is greedy.
void Decoder::decode()
{
  const bool rawInputMode = ProgramParameters::rawInput;
  if (!rawInputMode)
    config.reset();
  config.fillTapesWithInput();

  const bool useBeamSearch = ProgramParameters::beamSize > 1;
  if (useBeamSearch)
    decodeBeam();
  else
    decodeNoBeam();

  ProgramOutput::instance.print(stdout);
}

void Decoder::decodeNoBeam()
{
  float entropyAccumulator = 0.0;
  int nbActionsInSequence = 0;
  bool justFlipped = false;
  Errors errors;
  errors.newSequence();
  int nbActions = 0;
  int nbActionsCutoff = 200;
  float currentSpeed = 0.0;
  auto pastTime = std::chrono::high_resolution_clock::now();
  FILE * outputFile = stdout;

  config.setOutputFile(outputFile);

  if (ProgramParameters::debug)
    fprintf(stderr, "Begin decode\n");

  while (!config.isFinal())
  {
    if (ProgramParameters::debug)
      fprintf(stderr, "Config is not final\n");

    config.setCurrentStateName(tm.getCurrentClassifier()->name);
    Dict::currentClassifierName = tm.getCurrentClassifier()->name;

    auto weightedActions = tm.getCurrentClassifier()->weightActions(config);

    printAdvancement(config, currentSpeed, nbActionsCutoff);
    printDebugInfos(stderr, config, tm, weightedActions);

    std::pair<float,std::string> predictedAction;
    try {predictedAction = getClassifierAction(config, weightedActions, tm.getCurrentClassifier(), 0);}
    catch(EndOfDecode &) {break;}
    catch(NoMoreActions &) {continue;};

    checkAndRecordError(config, tm.getCurrentClassifier(), weightedActions, predictedAction.second, errors);

    if (ProgramParameters::showActions)
      nbActionsPerClassifier[tm.getCurrentClassifier()->name][predictedAction.second]++;

    applyActionAndTakeTransition(tm, predictedAction.second, config);

    nbActionsInSequence++;
    nbActions++;
    computeSpeed(pastTime, nbActions, nbActionsCutoff, currentSpeed);
    computeAndRecordEntropy(config, weightedActions, entropyAccumulator);
    computeAndPrintSequenceEntropy(config, justFlipped, errors, entropyAccumulator, nbActionsInSequence);
  }

  if (ProgramParameters::debug)
    fprintf(stderr, "Config is final\n");

  if (ProgramParameters::errorAnalysis)
    errors.printStats();

  config.printTheRest(false);

  if (ProgramParameters::interactive)
    fprintf(stderr, "                                                     \n");

  if (ProgramParameters::showActions)
    printActionsPerClassifier(stderr);
}

/// Prints, for each classifier, how often each action was chosen.
///
/// Each action line shows its percentage of the classifier's total and its
/// count. Lines are sorted by decreasing count; ties keep the map's order.
/// Each classifier section is preceded by an 80-dash separator.
void Decoder::printActionsPerClassifier(FILE * output)
{
  for (auto & classifier : nbActionsPerClassifier)
  {
    int total = 0;
    for (auto & action : classifier.second)
      total += action.second;

    for (int i = 0; i < 80; i++)
      fprintf(output, "-");
    fprintf(output, "\n%s :\n", classifier.first.c_str());

    // Keep (count, action name) and format each line only when printing, so
    // no fixed-size buffer can overflow on long action names.
    std::vector< std::pair<int, std::string> > toPrint;
    toPrint.reserve(classifier.second.size());
    for (auto & action : classifier.second)
      toPrint.emplace_back(action.second, action.first);

    std::stable_sort(toPrint.begin(), toPrint.end(), [](const std::pair<int, std::string> & a, const std::pair<int, std::string> & b){return a.first > b.first;});

    for (auto & entry : toPrint)
      fprintf(output, "\t%5.2f%%\t(%5d)\t%s\n", 100.0*entry.first/total, entry.first, entry.second.c_str());
  }
}

struct BeamNode
{
  Classifier::WeightedActions weightedActions;
  TransitionMachine tm;
  Config config;
  std::string action;
  int nbActions;
  bool justFlipped;
  int lastFlippedIndex;

  double getEntropy()
  {
    if (nbActions == 0)
      return 0.0;

    return config.getEntropy() / nbActions;
  }
  void setFlipped()
  {
    if (EOSisPredicted(config) && config.getHead() > lastFlippedIndex && config.getTape(ProgramParameters::sequenceDelimiterTape)[-1] == ProgramParameters::sequenceDelimiter)
    {
      justFlipped = true; 
      lastFlippedIndex = config.getHead();
    }
    else
    {
      justFlipped = false;
    }
  }
  BeamNode(TransitionMachine & tm, Config & config) : tm(tm), config(config)
  {
    justFlipped = false;
    nbActions = 0;
    lastFlippedIndex = 0;
    config.setOutputFile(nullptr);
    config.setEntropy(0.0);
  }
  BeamNode(BeamNode & other, const std::string & action, float proba) : tm(other.tm), config(other.config)
  {
    justFlipped = false;
    lastFlippedIndex = other.lastFlippedIndex;
    this->action = action;
    nbActions = other.nbActions + 1;
    config.setOutputFile(nullptr);
    config.addToEntropy(proba);
  }
};

void Decoder::decodeBeam()
{
  float entropyAccumulator = 0.0;
  int nbActionsInSequence = 0;
  bool justFlipped = false;
  Errors errors;
  errors.newSequence();
  int nbActions = 0;
  int nbActionsCutoff = 20;
  float currentSpeed = 0.0;
  auto pastTime = std::chrono::high_resolution_clock::now();

  FILE * outputFile = stdout;

  std::vector< std::shared_ptr<BeamNode> > beam;
  std::vector< std::shared_ptr<BeamNode> > otherBeam;
  std::vector< std::shared_ptr<BeamNode> > justFlippedBeam;
  beam.emplace_back(new BeamNode(tm, config));

  auto sortBeam = [&beam]()
  {
    std::sort(beam.begin(), beam.end(), [](const std::shared_ptr<BeamNode> & a, const std::shared_ptr<BeamNode> & b)
        {
          return a->getEntropy() > b->getEntropy();
        });
  };

  auto printBeam = [](std::vector< std::shared_ptr<BeamNode> > & beam)
  {
    for (auto & node : beam)
    {
      node->config.printForDebug(stderr);
      fprintf(stderr, "action : %s\n", node->action.c_str());
      fprintf(stderr, "nbActions : %d\n", node->nbActions);
      fprintf(stderr, "justFlipped : %s\n", node->justFlipped ? "true" : "false");
      fprintf(stderr, "lastFlippedIndex : %d\n", node->lastFlippedIndex);
      fprintf(stderr, "--------------------------------------------------------------------------------\n");
    }
  };

  bool endOfDecode = false;

  while (endOfDecode == false)
  {
    otherBeam.clear();

    bool mustContinue = false;
    for (auto & node : beam)
    {
      node->config.setCurrentStateName(node->tm.getCurrentClassifier()->name);
      Dict::currentClassifierName = node->tm.getCurrentClassifier()->name;

      node->weightedActions = node->tm.getCurrentClassifier()->weightActions(node->config);

      printAdvancement(node->config, currentSpeed, nbActionsCutoff);

      unsigned int nbActionsMax = std::min(std::max(node->tm.getCurrentClassifier()->getNbActions(),(unsigned int)1),(unsigned int)ProgramParameters::nbChilds);
      for (unsigned int actionIndex = 0; actionIndex < nbActionsMax; actionIndex++)
      {
        std::pair<float,std::string> predictedAction;
        try {predictedAction = getClassifierAction(node->config, node->weightedActions, node->tm.getCurrentClassifier(), actionIndex);}
        catch(EndOfDecode &) {mustContinue = true; break;}
        catch(NoMoreActions &) {break;};
        otherBeam.emplace_back(new BeamNode(*node.get(), predictedAction.second, predictedAction.first));
      }

      if (mustContinue)
        break;
    }

    if (ProgramParameters::debug)
    {
      fprintf(stderr, "################################# Beam before sort #################################\n");
      printBeam(otherBeam);
      fprintf(stderr, "####################################################################################\n");
    }

    beam = otherBeam;
    sortBeam();
    beam.resize(std::min((int)beam.size(), ProgramParameters::beamSize));
    if (beam.empty())
    {
      fprintf(stderr, "ERROR (%s) : beam is empty. Aborting.\n", ERRINFO);
      exit(1);
    }

    if (ProgramParameters::debug)
    {
      fprintf(stderr, "################################# Beam after sort #################################\n");
      printBeam(beam);
      fprintf(stderr, "###################################################################################\n");
    }

    for (auto & node : beam)
      node->config.setOutputFile(outputFile);

    for (auto & node : beam)
    {
      config.setCurrentStateName(node->tm.getCurrentState());
      Dict::currentClassifierName = node->tm.getCurrentClassifier()->name;

      if (node.get() == beam.begin()->get())
        checkAndRecordError(node->config, node->tm.getCurrentClassifier(), node->weightedActions, node->action, errors);

      applyActionAndTakeTransition(node->tm, node->action, node->config);

      if (node.get() == beam.begin()->get())
      {
        nbActionsInSequence++;
        nbActions++;
        computeSpeed(pastTime, nbActions, nbActionsCutoff, currentSpeed);
        computeAndRecordEntropy(node->config, node->weightedActions, entropyAccumulator);
        computeAndPrintSequenceEntropy(node->config, justFlipped, errors, entropyAccumulator, nbActionsInSequence);
      }
      node->setFlipped();
    }

    for (unsigned int i = 0; i < beam.size(); i++)
    {
      if (beam[i]->justFlipped)
      {
        justFlippedBeam.push_back(beam[i]);
        beam[i] = beam[beam.size()-1];
        beam.pop_back();
        i--;
      }
    }

    if ((int)justFlippedBeam.size() >= ProgramParameters::beamSize || (justFlippedBeam.size() && mustContinue))
    {
      if (mustContinue || justFlippedBeam[0]->config.endOfTapes())
        endOfDecode = true;
      beam = justFlippedBeam;
      justFlippedBeam.clear();
      sortBeam();
      beam.resize(1);
      beam[0]->config.setEntropy(0.0);
      beam[0]->nbActions = 0;
    }

    if (!EOSisPredicted(beam[0]->config) && beam[0]->config.endOfTapes())
      endOfDecode = true;
  }

  if (ProgramParameters::errorAnalysis)
    errors.printStats();

  for (auto node : beam)
  {
    node->config.setOutputFile(outputFile);
    node->config.printTheRest(false);
  }

  if (ProgramParameters::interactive)
    fprintf(stderr, "                                                     \n");
}