#include "Trainer.hpp"
#include "util.hpp"

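// Build a Trainer for the given transition machine and training corpus.
// The second constructor also takes a development corpus, used to compute
// dev scores during training.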
Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config)
: Trainer(tm, bd, config, nullptr, nullptr)
{
}

Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig)
: tm(tm), trainBD(bd), trainConfig(config), devBD(devBD), devConfig(devConfig)
{
  nbSteps = 0;
  nbActions = 0;
  nbActionsCutoff = 2*ProgramParameters::batchSize;
  currentSpeed = 0.0;
  pastTime = std::chrono::high_resolution_clock::now();
}

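// When randomDebug is active, randomly enable or disable debug output for the
// current step (with probability randomDebugProbability), printing a small
// header whenever it is enabled.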
void Trainer::setDebugValue()
{
  if (!ProgramParameters::randomDebug)
    return;

  ProgramParameters::debug = choiceWithProbability(ProgramParameters::randomDebugProbability);

  if (!ProgramParameters::debug)
    return;

  if (ProgramParameters::interactive)
    fprintf(stderr, "                            \r");

  fprintf(stderr, "\n");

  if (ProgramParameters::printTime)
    fprintf(stderr, "[%s] :\n", getTime().c_str());
}

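// Run the transition machine over the whole development corpus : non trainable
// classifiers apply their oracle action, trainable ones apply either their
// predicted action or a zero cost action (depending on devEvalOnGold).
// Finally let TI compute the dev scores.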
void Trainer::computeScoreOnDev()
{
  if (!devConfig)
    return;

  tm.reset();
  devConfig->reset();

  if (ProgramParameters::debug)
    fprintf(stderr, "Computing score on dev set\n");

  int nbActionsInSequence = 0;
  float entropyAccumulator = 0.0;
  bool justFlipped = false;
  int nbActions = 0;
  int nbActionsCutoff = 2*ProgramParameters::batchSize;
  currentSpeed = 0.0;
  auto pastTime = std::chrono::high_resolution_clock::now();
  std::vector<float> entropies;

  while (!devConfig->isFinal())
  {
    setDebugValue();
    devConfig->setCurrentStateName(tm.getCurrentClassifier()->name);
    Dict::currentClassifierName = tm.getCurrentClassifier()->name;
    tm.getCurrentClassifier()->initClassifier(*devConfig);

    if(!tm.getCurrentClassifier()->needsTrain())
    {
      int neededActionIndex = tm.getCurrentClassifier()->getOracleActionIndex(*devConfig);
      std::string neededActionName = tm.getCurrentClassifier()->getActionName(neededActionIndex);
      Action * action = tm.getCurrentClassifier()->getAction(neededActionName);
      TransitionMachine::Transition * transition = tm.getTransition(neededActionName);
      action->setInfos(tm.getCurrentClassifier()->name);

      action->apply(*devConfig);
      tm.takeTransition(transition);
    }
    else
    {
      // Print current evaluation advancement in percentage
      if (ProgramParameters::interactive) 
      {
        int totalSize = ProgramParameters::devTapeSize;
        int steps = devConfig->getHead();
        if (steps && (steps % nbActionsCutoff == 0 || totalSize-steps < nbActionsCutoff))
        {
          fprintf(stderr, "                                                      \r");
          fprintf(stderr, "Eval on dev : %.2f%% speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr((int)currentSpeed).c_str());
        }
      }

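      // pAction : the classifier's predicted action (first allowed entry of
      // weightedActions) ; oAction : the first allowed action with zero cost.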
      auto weightedActions = tm.getCurrentClassifier()->weightActions(*devConfig);
      std::string pAction = "";
      std::string oAction = "";

      for (auto & it : weightedActions)
        if (it.first)
        {
          if (pAction.empty())
            pAction = it.second.second;

          if (tm.getCurrentClassifier()->getActionCost(*devConfig, it.second.second) == 0)
          {
            oAction = it.second.second;
            break;
          }
        }

      if (ProgramParameters::devLoss)
      {
        float loss = tm.getCurrentClassifier()->getLoss(*devConfig, tm.getCurrentClassifier()->getActionIndex(oAction));
        TI.addDevLoss(tm.getCurrentClassifier()->name, loss);
      }

      std::string actionName;

      if (ProgramParameters::devEvalOnGold)
        actionName = oAction;
      else
        actionName = pAction;

      Action * action = tm.getCurrentClassifier()->getAction(actionName);

      if (ProgramParameters::debug)
      {
        fprintf(stderr, "Speed : %s actions/s\n", int2humanStr((int)currentSpeed).c_str());
        devConfig->printForDebug(stderr);
        fprintf(stderr, "pAction=<%s> action=<%s>\n", pAction.c_str(), actionName.c_str());
      }

      TransitionMachine::Transition * transition = tm.getTransition(actionName);
      action->setInfos(tm.getCurrentClassifier()->name);
      devConfig->addToActionsHistory(tm.getCurrentClassifier()->name, actionName, tm.getCurrentClassifier()->getActionCost(*devConfig, actionName));

      action->apply(*devConfig);
      tm.takeTransition(transition);

      nbActions++;

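      // Periodically re-estimate the processing speed (actions per second).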
      if (nbActions >= nbActionsCutoff)
      {
        auto actualTime = std::chrono::high_resolution_clock::now();

        double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
        currentSpeed = nbActions / seconds;

        pastTime = actualTime;

        nbActions = 0;
      }

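      // Record the entropy of the predicted distribution ; when printEntropy is
      // set, also average it over each sequence.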
      float entropy = Classifier::computeEntropy(weightedActions);
      devConfig->addToEntropyHistory(entropy);

      if (ProgramParameters::printEntropy)
      {
        nbActionsInSequence++;

        entropyAccumulator += entropy;

        if (devConfig->getHead() >= 1 && devConfig->getTape(ProgramParameters::sequenceDelimiterTape)[devConfig->getHead()-1] != ProgramParameters::sequenceDelimiter)
          justFlipped = false;

        if ((devConfig->getHead() >= 1 && devConfig->getTape(ProgramParameters::sequenceDelimiterTape)[devConfig->getHead()-1] == ProgramParameters::sequenceDelimiter && !justFlipped))
        {
          justFlipped = true;
          entropyAccumulator /= nbActionsInSequence;
          nbActionsInSequence = 0;
          entropies.emplace_back(entropyAccumulator);
          entropyAccumulator = 0.0;
        }
      }

    }
  }

  if (ProgramParameters::debug)
    fprintf(stderr, "Dev Config is final\n");

  TI.computeDevScores(*devConfig);

  if (ProgramParameters::debug)
    fprintf(stderr, "End of %s\n", __func__);
}

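// Reset the machine and the training configuration, shuffling the training
// sequences if requested.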
void Trainer::resetAndShuffle()
{
  tm.reset();
  trainConfig.reset();

  if(ProgramParameters::shuffleExamples)
    trainConfig.shuffle(ProgramParameters::sequenceDelimiterTape, ProgramParameters::sequenceDelimiter);
}

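// One step with a non trainable classifier : apply the oracle action and take
// the corresponding transition.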
void Trainer::doStepNoTrain()
{
  int neededActionIndex = tm.getCurrentClassifier()->getOracleActionIndex(trainConfig);
  std::string neededActionName = tm.getCurrentClassifier()->getActionName(neededActionIndex);
  if (ProgramParameters::debug)
  {
    fprintf(stderr, "Speed : %s actions/s\n", int2humanStr((int)currentSpeed).c_str());
    trainConfig.printForDebug(stderr);
    fprintf(stderr, "action=<%s>\n", neededActionName.c_str());
  }

  Action * action = tm.getCurrentClassifier()->getAction(neededActionName);
  TransitionMachine::Transition * transition = tm.getTransition(neededActionName);
  action->setInfos(tm.getCurrentClassifier()->name);
  trainConfig.addToActionsHistory(tm.getCurrentClassifier()->name, action->name, tm.getCurrentClassifier()->getActionCost(trainConfig, action->name));

  action->apply(trainConfig);
  tm.takeTransition(transition);
}

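// One training step : weight the possible actions with the current classifier,
// train it towards a zero cost action, then apply either the predicted or the
// oracle action (dynamic oracle) and take the corresponding transition.
// Error_* classifiers are handled separately (see below).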
void Trainer::doStepTrain()
{
  if (!TI.isTopologyPrinted(tm.getCurrentClassifier()->name))
  {
    TI.setTopologyPrinted(tm.getCurrentClassifier()->name);
    tm.getCurrentClassifier()->printTopology(stderr);
  }

  // Print current iter advancement in percentage
  if (ProgramParameters::interactive)
  {
    int totalSize = ProgramParameters::iterationSize == -1 ? ProgramParameters::tapeSize : ProgramParameters::iterationSize;
    int steps = ProgramParameters::iterationSize == -1 ? trainConfig.getHead() : nbSteps;
    if (steps % nbActionsCutoff == 0 || totalSize-steps < nbActionsCutoff)
    {
      fprintf(stderr, "                                                      \r");
      fprintf(stderr, "Current Iteration : %.2f%% speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr((int)currentSpeed).c_str());
    }
  }

  std::string pAction = "";
  std::string oAction = "";

  std::string actionName = "";
  float loss = 0.0;

  Classifier::WeightedActions weightedActions;
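  // Normal classifier (its name does not start with "Error_").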
  if (tm.getCurrentClassifier()->name.rfind("Error_", 0) != 0)
  {
    if (!ProgramParameters::featureExtraction)
    {
      weightedActions = tm.getCurrentClassifier()->weightActions(trainConfig);
  
      for (auto & it : weightedActions)
        if (it.first)
        {
          pAction = it.second.second;
          break;
        }

      auto zeroCosts = tm.getCurrentClassifier()->getZeroCostActions(trainConfig);
      if (!zeroCosts.empty())
        oAction = zeroCosts[0];
    }
    else
    {
      // Feature extraction mode : only the oracle action is needed.
      auto zeroCosts = tm.getCurrentClassifier()->getZeroCostActions(trainConfig);
      if (!zeroCosts.empty())
        oAction = zeroCosts[0];
    }
  
    if (oAction.empty())
      oAction = tm.getCurrentClassifier()->getDefaultAction();
  
    if (oAction.empty())
    {
      if (trainConfig.endOfTapes())
      {
        while (!trainConfig.stackEmpty())
          trainConfig.stackPop();
        throw EndOfIteration();
      }
  
      fprintf(stderr, "ERROR (%s) : Unable to find any zero cost action. Aborting.\n", ERRINFO);
      fprintf(stderr, "State : %s\n", tm.getCurrentState().c_str());
      trainConfig.printForDebug(stderr);
      tm.getCurrentClassifier()->explainCostOfActions(stderr, trainConfig);
      exit(1);
    }

    if (!ProgramParameters::featureExtraction)
      loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex(oAction));
  
    TI.addTrainLoss(tm.getCurrentClassifier()->name, loss);
  
    if (ProgramParameters::featureExtraction)
    {
      auto features = tm.getCurrentClassifier()->getFeatureModel()->getFeatureDescription(trainConfig).featureValues();
      fprintf(stdout, "%s\t%s\t%s\n", tm.getCurrentClassifier()->getFeatureModel()->filename.c_str(), oAction.c_str(), features.c_str());
    }
  
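    // Dynamic oracle : from epoch dynamicEpoch on, follow the predicted action
    // with probability dynamicProbability, otherwise follow the oracle action.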
    if (TI.getEpoch() >= ProgramParameters::dynamicEpoch && choiceWithProbability(ProgramParameters::dynamicProbability))
    {
      actionName = pAction;
      TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = true;
    }
    else
    {
      actionName = oAction;
      TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = oAction == pAction;
    }
  
    if (ProgramParameters::debug)
    {
      fprintf(stderr, "Speed : %s actions/s\n", int2humanStr((int)currentSpeed).c_str());
      trainConfig.printForDebug(stderr);
      tm.getCurrentClassifier()->printWeightedActions(stderr, weightedActions, 10);
      fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str());
    }

  }
  else
  {
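    // Error_* classifier : it predicts BACK / EPSILON actions for the
    // corresponding normal state, and is only updated after a BACK, once the
    // effect of that BACK on the action cost is known.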
    if (!ProgramParameters::featureExtraction)
    {
      weightedActions = tm.getCurrentClassifier()->weightActions(trainConfig);
  
      for (auto & it : weightedActions)
        if (it.first)
        {
          pAction = it.second.second;
          break;
        }

      auto zeroCosts = tm.getCurrentClassifier()->getZeroCostActions(trainConfig);
      if (!zeroCosts.empty())
        oAction = zeroCosts[rand() % zeroCosts.size()];
    }
    else
    {
      auto zeroCosts = tm.getCurrentClassifier()->getZeroCostActions(trainConfig);
      if (!zeroCosts.empty())
        oAction = zeroCosts[0];
    }
  
    if (oAction.empty())
      oAction = tm.getCurrentClassifier()->getDefaultAction();
  
    if (oAction.empty())
    {
      if (trainConfig.endOfTapes())
      {
        while (!trainConfig.stackEmpty())
          trainConfig.stackPop();
        throw EndOfIteration();
      }
    }

    actionName = pAction;
    if (TI.getEpoch() < ProgramParameters::dynamicEpoch)
      actionName = oAction;
    else if (actionName == "EPSILON")
      actionName = oAction;

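    // Recover the name of the normal state this Error_ classifier is attached
    // to, by stripping the "Error_" prefix.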
    char buffer[1024];
    if (sscanf(trainConfig.getCurrentStateName().c_str(), "Error_%1023s", buffer) != 1)
    {
      fprintf(stderr, "ERROR (%s) : unexpected classifier name \'%s\'. Aborting.\n", ERRINFO, trainConfig.getCurrentStateName().c_str());
      exit(1);
    }
    std::string normalStateName(buffer);

    auto & normalHistory = trainConfig.getActionsHistory(normalStateName);

    // If a BACK just happened
    if (normalHistory.size() > 1 && trainConfig.getCurrentStateHistory().size() > 0 && split(trainConfig.getCurrentStateHistory().top(), ' ')[0] == "BACK" && TI.getEpoch() >= ProgramParameters::dynamicEpoch)
    {
      auto & lastAction = trainConfig.lastUndoneAction;
      auto & newAction = normalHistory[normalHistory.size()-1];
      auto & lastActionName = lastAction.first;
      auto & newActionName = newAction.first;
      auto lastCost = lastAction.second;
      auto newCost = newAction.second;
      
      if (ProgramParameters::debug)
      {
        fprintf(stderr, "<%s>(%d) -> <%s>(%d)\n", lastActionName.c_str(), lastCost, newActionName.c_str(), newCost);
      }

      if (TI.lastActionWasPredicted[normalStateName])
      {
        std::string updateInfos;

        // Choice between a hard-target and a soft-target update for this
        // corrective step; the soft-target variant is currently disabled.
        constexpr bool useHardTarget = true;

        if (newCost >= lastCost)
        {
          // The BACK did not lead to a cheaper action : the predicted BACK was a bad decision.
          if (useHardTarget)
          {
            loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], -(tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top())+1));
          }
          else
          {
            // Train towards a uniform distribution over every action except the predicted BACK.
            int nbActions = tm.getCurrentClassifier()->getNbActions();
            int backIndex = tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top());
            float value = 1.0 / (nbActions-1);
            std::vector<float> goldOutput(nbActions, value);
            goldOutput[backIndex] = 0.0;

            loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], goldOutput);
          }

          updateInfos = "predicted : <"+trainConfig.getCurrentStateHistory().top()+">, bad decision";
        }
        else
        {
          // The BACK led to a cheaper action : the predicted BACK was a good decision.
          if (useHardTarget)
          {
            loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top()));
          }
          else
          {
            // Train towards a one-hot distribution on the predicted BACK.
            int nbActions = tm.getCurrentClassifier()->getNbActions();
            int backIndex = tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top());
            std::vector<float> goldOutput(nbActions, 0.0);
            goldOutput[backIndex] = 1.0;

            loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], goldOutput);
          }

          updateInfos = "predicted : <"+trainConfig.getCurrentStateHistory().top()+">, good decision";
        }

        if (ProgramParameters::debug)
          fprintf(stderr, "Updating neural network \'%s\' : %s\n", tm.getCurrentClassifier()->name.c_str(), updateInfos.c_str());

        TI.addTrainLoss(tm.getCurrentClassifier()->name, loss);
      }

    }
  
    if (ProgramParameters::debug)
    {
      fprintf(stderr, "Speed : %s actions/s\n", int2humanStr((int)currentSpeed).c_str());
      trainConfig.printForDebug(stderr);
      tm.getCurrentClassifier()->printWeightedActions(stderr, weightedActions, 10);
      fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str());
    }

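    // Remember the features used for this (non EPSILON) decision : they will be
    // needed to update this classifier once the effect of the BACK is known.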
    if (actionName != "EPSILON")
    {
      pendingFD[tm.getCurrentClassifier()->name] = tm.getCurrentClassifier()->getFeatureDescription(trainConfig);
    }
  }

  Action * action = tm.getCurrentClassifier()->getAction(actionName);
  TransitionMachine::Transition * transition = tm.getTransition(actionName);
  action->setInfos(tm.getCurrentClassifier()->name);

  trainConfig.addToActionsHistory(tm.getCurrentClassifier()->name, actionName, tm.getCurrentClassifier()->getActionCost(trainConfig, actionName));

  action->apply(trainConfig);
  tm.takeTransition(transition);

  nbActions++;

  if (nbActions >= nbActionsCutoff)
  {
    auto actualTime = std::chrono::high_resolution_clock::now();

    double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
    currentSpeed = nbActions / seconds;

    pastTime = actualTime;

    nbActions = 0;
  }

  float entropy = Classifier::computeEntropy(weightedActions);
  trainConfig.addToEntropyHistory(entropy);
}

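// Notify every trainable classifier that the current training iteration is over.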
void Trainer::endOfIteration()
{
  auto classifiers = tm.getClassifiers();
  for (auto * cla : classifiers)
    if (cla->needsTrain())
      cla->endOfIteration();
}

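// Close the current epoch : print the scores, save the models that must be
// saved, then move on to the next epoch. Throws EndOfTraining once the last
// epoch has been completed.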
void Trainer::prepareNextEpoch()
{
  endOfIteration();

  printScoresAndSave(stderr);
  nbSteps = 0;
  TI.nextEpoch();

  endOfIteration();

  if (TI.getEpoch() > ProgramParameters::nbIter)
    throw EndOfTraining();
}

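// Main training loop : for every epoch, run the machine over the (optionally
// shuffled) training corpus, training the classifiers step by step, then
// evaluate, save and start the next epoch.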
void Trainer::train()
{
  Dict::createFiles(ProgramParameters::expPath, "");

  fprintf(stderr, "%sTraining of \'%s\' :\n", 
    ProgramParameters::printTime ? ("["+getTime()+"] ").c_str() : "",
    tm.name.c_str());

  while (TI.getEpoch() <= ProgramParameters::nbIter)
  {
    resetAndShuffle();
    while (!trainConfig.isFinal())
    {
      setDebugValue();
      trainConfig.setCurrentStateName(tm.getCurrentClassifier()->name);
      Dict::currentClassifierName = tm.getCurrentClassifier()->name;
      tm.getCurrentClassifier()->initClassifier(trainConfig);

      if(!tm.getCurrentClassifier()->needsTrain())
        doStepNoTrain();
      else
        try {doStepTrain();}
        catch (EndOfIteration &) {break;}

      nbSteps++;

      if (ProgramParameters::iterationSize != -1 && nbSteps >= ProgramParameters::iterationSize)
        try {prepareNextEpoch();}
        catch (EndOfTraining &) {break;}
    }

    if (ProgramParameters::debug)
      fprintf(stderr, "Config is final\n");

    if (ProgramParameters::iterationSize == -1)
      try {prepareNextEpoch();}
      catch (EndOfTraining &) {break;}

    if (ProgramParameters::debug)
      fprintf(stderr, "End of epoch\n");
  }
}

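// Compute the scores on the training and development corpora, save every
// classifier (and its dictionaries) that TI marks as needing to be saved, and
// print the scores to output.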
void Trainer::printScoresAndSave(FILE * output)
{
  TI.computeTrainScores(trainConfig);
  computeScoreOnDev();
  TI.computeMustSaves();

  auto classifiers = tm.getClassifiers();
  for (auto * cla : classifiers)
    if (TI.mustSave(cla->name))
    {
      if (ProgramParameters::debug)
        fprintf(stderr, "Saving %s...", cla->name.c_str());
      cla->save(ProgramParameters::expPath + cla->name + ".model");
      Dict::saveDicts(ProgramParameters::expPath, cla->name);
      if (ProgramParameters::debug)
        fprintf(stderr, "Done !\n");
    }

  TI.printScores(output);

  if (ProgramParameters::debug)
    fprintf(stderr, "End of %s\n", __func__);
}