#include "Trainer.hpp"
#include "util.hpp"

/// @brief Build a Trainer that has no development data.
///
/// Both dev pointers are left null, which makes computeScoreOnDev() return
/// immediately without evaluating anything.
///
/// @param tm The TransitionMachine whose classifiers will be trained.
/// @param bd The BD associated with the training Config.
/// @param config The Config used for training.
Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config)
  : tm(tm),
    trainBD(bd),
    trainConfig(config),
    devBD(nullptr),
    devConfig(nullptr)
{
}

/// @brief Build a Trainer with both training and development data.
///
/// @param tm The TransitionMachine whose classifiers will be trained.
/// @param bd The BD associated with the training Config.
/// @param config The Config used for training.
/// @param devBD The BD associated with the dev Config (stored as given, may be null).
/// @param devConfig The Config used for evaluation; when null, dev scoring is skipped.
Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig)
  : tm(tm),
    trainBD(bd),
    trainConfig(config),
    devBD(devBD),
    devConfig(devConfig)
{
}

void Trainer::computeScoreOnDev()
{
  if (!devConfig)
    return;

  tm.reset();
  devConfig->reset();

  if (ProgramParameters::debug)
    fprintf(stderr, "Computing score on dev set\n");

  int nbActionsInSequence = 0;
  float entropyAccumulator = 0.0;
  bool justFlipped = false;
  int nbActions = 0;
  int nbActionsCutoff = 2*ProgramParameters::batchSize;
  float currentSpeed = 0.0;
  auto pastTime = std::chrono::high_resolution_clock::now();
  std::vector<float> entropies;

  while (!devConfig->isFinal())
  {
    devConfig->setCurrentStateName(tm.getCurrentState());
    Dict::currentClassifierName = tm.getCurrentClassifier()->name;
    tm.getCurrentClassifier()->initClassifier(*devConfig);

    if(!tm.getCurrentClassifier()->needsTrain())
    {
      int neededActionIndex = tm.getCurrentClassifier()->getOracleActionIndex(*devConfig);
      std::string neededActionName = tm.getCurrentClassifier()->getActionName(neededActionIndex);
      Action * action = tm.getCurrentClassifier()->getAction(neededActionName);
      TransitionMachine::Transition * transition = tm.getTransition(neededActionName);
      action->setInfos(transition->headMvt, tm.getCurrentState());

      action->apply(*devConfig);
      tm.takeTransition(transition);
    }
    else
    {
      // Print current iter advancement in percentage
      if (ProgramParameters::interactive)
      {
        int totalSize = ProgramParameters::devTapeSize;
        int steps = devConfig->getHead();
        if (steps && (steps % nbActionsCutoff == 0 || totalSize-steps < nbActionsCutoff))
        {
          fprintf(stderr, "                                                      \r");
          fprintf(stderr, "Eval on dev : %.2f%% speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr((int)currentSpeed).c_str());
        }
      }

      auto weightedActions = tm.getCurrentClassifier()->weightActions(*devConfig);
      std::string pAction = "";

      for (auto & it : weightedActions)
        if (it.first)
        {
          pAction = it.second.second;
          break;
        }

      bool pActionIsZeroCost = tm.getCurrentClassifier()->getActionCost(*devConfig, pAction) == 0;

      TI.addDevExample(tm.getCurrentClassifier()->name);
      if (pActionIsZeroCost)
        TI.addDevSuccess(tm.getCurrentClassifier()->name);

      std::string actionName = pAction;
      Action * action = tm.getCurrentClassifier()->getAction(actionName);

      if (ProgramParameters::debug)
      {
        devConfig->printForDebug(stderr);
        fprintf(stderr, "pAction=<%s> action=<%s>\n", pAction.c_str(), actionName.c_str());
      }

      TransitionMachine::Transition * transition = tm.getTransition(actionName);
      action->setInfos(transition->headMvt, tm.getCurrentState());
      action->apply(*devConfig);
      tm.takeTransition(transition);

      nbActions++;

      if (nbActions >= nbActionsCutoff)
      {
        auto actualTime = std::chrono::high_resolution_clock::now();

        double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
        currentSpeed = nbActions / seconds;

        pastTime = actualTime;

        nbActions = 0;
      }

      float entropy = Classifier::computeEntropy(weightedActions);
      devConfig->addToEntropyHistory(entropy);

      if (ProgramParameters::printEntropy)
      {
        nbActionsInSequence++;

        entropyAccumulator += entropy;

        if (devConfig->getHead() >= 1 && devConfig->getTape(ProgramParameters::sequenceDelimiterTape)[devConfig->getHead()-1] != ProgramParameters::sequenceDelimiter)
        justFlipped = false;

        if ((devConfig->getHead() >= 1 && devConfig->getTape(ProgramParameters::sequenceDelimiterTape)[devConfig->getHead()-1] == ProgramParameters::sequenceDelimiter && !justFlipped))
        {
          justFlipped = true;
          entropyAccumulator /= nbActionsInSequence;
          nbActionsInSequence = 0;
          entropies.emplace_back(entropyAccumulator);
          entropyAccumulator = 0.0;
        }
      }

    }
  }

  if (ProgramParameters::debug)
    fprintf(stderr, "Dev Config is final\n");

  TI.computeDevScores();

  if (ProgramParameters::debug)
    fprintf(stderr, "End of %s\n", __func__);
}

void Trainer::train()
{
  Dict::createFiles(ProgramParameters::expPath, "");

  fprintf(stderr, "%sTraining of \'%s\' :\n", 
    ProgramParameters::printTime ? ("["+getTime()+"] ").c_str() : "",
    tm.name.c_str());

  auto resetAndShuffle = [this]()
  {
    tm.reset();
    trainConfig.reset();

    if(ProgramParameters::shuffleExamples)
      trainConfig.shuffle(ProgramParameters::sequenceDelimiterTape, ProgramParameters::sequenceDelimiter);

    TI.resetCounters();
  };

  int nbSteps = 0;
  int nbActions = 0;
  int nbActionsCutoff = 2*ProgramParameters::batchSize;
  float currentSpeed = 0.0;
  auto pastTime = std::chrono::high_resolution_clock::now();
  while (TI.getEpoch() <= ProgramParameters::nbIter)
  {
    resetAndShuffle();
    while (!trainConfig.isFinal())
    {
      trainConfig.setCurrentStateName(tm.getCurrentState());
      Dict::currentClassifierName = tm.getCurrentClassifier()->name;
      tm.getCurrentClassifier()->initClassifier(trainConfig);

      if(!tm.getCurrentClassifier()->needsTrain())
      {
        int neededActionIndex = tm.getCurrentClassifier()->getOracleActionIndex(trainConfig);
        std::string neededActionName = tm.getCurrentClassifier()->getActionName(neededActionIndex);
        if (ProgramParameters::debug)
        {
          trainConfig.printForDebug(stderr);
          fprintf(stderr, "action=<%s>\n", neededActionName.c_str());
        }

        Action * action = tm.getCurrentClassifier()->getAction(neededActionName);
        TransitionMachine::Transition * transition = tm.getTransition(neededActionName);
        action->setInfos(transition->headMvt, tm.getCurrentState());

        action->apply(trainConfig);
        tm.takeTransition(transition);
      }
      else
      {
        if (!TI.isTopologyPrinted(tm.getCurrentClassifier()->name))
        {
          TI.setTopologyPrinted(tm.getCurrentClassifier()->name);
          tm.getCurrentClassifier()->printTopology(stderr);
        }

        // Print current iter advancement in percentage
        if (ProgramParameters::interactive)
        {
          int totalSize = ProgramParameters::iterationSize == -1 ? ProgramParameters::tapeSize : ProgramParameters::iterationSize;
          int steps = ProgramParameters::iterationSize == -1 ? trainConfig.getHead() : nbSteps;
          if (steps % nbActionsCutoff == 0 || totalSize-steps < nbActionsCutoff)
          {
            fprintf(stderr, "                                                      \r");
            fprintf(stderr, "Current Iteration : %.2f%% speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr((int)currentSpeed).c_str());
          }
        }

        auto weightedActions = tm.getCurrentClassifier()->weightActions(trainConfig);
        std::string pAction = "";
        std::string oAction = "";

        bool pActionIsZeroCost = false;

        for (auto & it : weightedActions)
          if (it.first)
          {
            if (pAction == "")
              pAction = it.second.second;

            if (tm.getCurrentClassifier()->getActionCost(trainConfig, it.second.second) == 0)
            {
              oAction = it.second.second;
              break;
            }
          }

        if (pAction == oAction)
          pActionIsZeroCost = true;

        if (oAction.empty())
          oAction = tm.getCurrentClassifier()->getDefaultAction();

        if (oAction.empty())
        {
          if (trainConfig.endOfTapes())
          {
            while (!trainConfig.stackEmpty())
              trainConfig.stackPop();
            break;
          }

          fprintf(stderr, "ERROR (%s) : Unable to find any zero cost action. Aborting.\n", ERRINFO);
          fprintf(stderr, "State : %s\n", tm.getCurrentState().c_str());
          trainConfig.printForDebug(stderr);
          tm.getCurrentClassifier()->explainCostOfActions(stderr, trainConfig);
          exit(1);
        }

        tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex(oAction));

        TI.addTrainExample(tm.getCurrentClassifier()->name);
        if (pActionIsZeroCost)
          TI.addTrainSuccess(tm.getCurrentClassifier()->name);

        int k = ProgramParameters::dynamicEpoch;

        std::string actionName = "";

        if (TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability))
        {
          actionName = pAction;
        }
        else
        {
          if (pActionIsZeroCost)
            actionName = pAction;
          else
            actionName = oAction;
        }

        if (ProgramParameters::debug)
        {
          trainConfig.printForDebug(stderr);
          fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str());
        }

        Action * action = tm.getCurrentClassifier()->getAction(actionName);
        TransitionMachine::Transition * transition = tm.getTransition(actionName);
        action->setInfos(transition->headMvt, tm.getCurrentState());

        action->apply(trainConfig);
        tm.takeTransition(transition);

        nbActions++;

        if (nbActions >= nbActionsCutoff)
        {
          auto actualTime = std::chrono::high_resolution_clock::now();

          double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
          currentSpeed = nbActions / seconds;

          pastTime = actualTime;

          nbActions = 0;
        }

        float entropy = Classifier::computeEntropy(weightedActions);
        trainConfig.addToEntropyHistory(entropy);
      }

      nbSteps++;
      if (ProgramParameters::iterationSize != -1 && nbSteps >= ProgramParameters::iterationSize)
      {
        printScoresAndSave(stderr);
        nbSteps = 0;
        TI.nextEpoch();

        if (TI.getEpoch() > ProgramParameters::nbIter)
          break;
      }
    }

    if (ProgramParameters::debug)
      fprintf(stderr, "Config is final\n");

    if (ProgramParameters::iterationSize == -1)
    {
      printScoresAndSave(stderr);
      nbSteps = 0;
      TI.nextEpoch();

      if (TI.getEpoch() > ProgramParameters::nbIter)
        break;
    }

    if (ProgramParameters::debug)
      fprintf(stderr, "End of epoch\n");
  }
}

/// @brief Compute train and dev scores, save the classifiers that must be saved, and print the scores.
///
/// For every classifier flagged by TI.mustSave(), the model is written to
/// "<expPath><name>.model" and its dicts are saved through Dict::saveDicts().
///
/// @param output Where the scores are printed (debug messages always go to stderr).
void Trainer::printScoresAndSave(FILE * output)
{
  TI.computeTrainScores();
  computeScoreOnDev();
  TI.computeMustSaves();

  auto allClassifiers = tm.getClassifiers();
  for (auto * classifier : allClassifiers)
  {
    if (!TI.mustSave(classifier->name))
      continue;

    if (ProgramParameters::debug)
      fprintf(stderr, "Saving %s...", classifier->name.c_str());

    const std::string modelPath = ProgramParameters::expPath + classifier->name + ".model";
    classifier->save(modelPath);
    Dict::saveDicts(ProgramParameters::expPath, classifier->name);

    if (ProgramParameters::debug)
      fprintf(stderr, "Done !\n");
  }

  TI.printScores(output);

  if (ProgramParameters::debug)
    fprintf(stderr, "End of %s\n", __func__);
}