#include "Trainer.hpp"
#include "SubConfig.hpp"
#include <execution>

// Creates a trainer operating on `machine`. `batchSize` is the batch size used by
// the training and dev data loaders built later (makeDataLoader / makeDevDataLoader).
Trainer::Trainer(ReadingMachine & machine, int batchSize)
  : machine(machine),
    batchSize(batchSize)
{}

// (Re)builds the training dataset from `dir` and the data loader iterating over it
// in batches of `batchSize`. workers(0) means batches are loaded synchronously on
// the calling thread.
void Trainer::makeDataLoader(std::filesystem::path dir)
{
  trainDataset.reset(new Dataset(dir));

  auto loaderOptions = torch::data::DataLoaderOptions(batchSize)
    .workers(0)
    .max_jobs(0);
  dataLoader = torch::data::make_data_loader(*trainDataset, loaderOptions);
}

// (Re)builds the dev dataset from `dir` and the data loader iterating over it in
// batches of `batchSize`. workers(0) means batches are loaded synchronously on the
// calling thread.
void Trainer::makeDevDataLoader(std::filesystem::path dir)
{
  devDataset.reset(new Dataset(dir));

  auto loaderOptions = torch::data::DataLoaderOptions(batchSize)
    .workers(0)
    .max_jobs(0);
  devDataLoader = torch::data::make_data_loader(*devDataset, loaderOptions);
}

void Trainer::createDataset(std::vector<BaseConfig> & goldConfigs, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold, bool memcheck)
{
  std::vector<SubConfig> configs;
  for (auto & goldConfig : goldConfigs)
    configs.emplace_back(goldConfig, goldConfig.getNbLines());

  machine.trainMode(false);

  extractExamples(configs, debug, dir, epoch, dynamicOracle, explorationThreshold, memcheck);

  machine.saveDicts();
}

// Runs the reading machine on every config of `configs`, from its strategy's
// initial state until the strategy's end movement. Records one training example
// (extracted context + gold class(es)) for every step whose example is not banned
// by the classifier.
//
// Examples are grouped by machine state and written to `dir` through
// Examples::saveIfNeeded every maxNbExamplesPerFile examples, then flushed once
// more at the end. On completion an empty marker file
// "extracted.<epoch>.<dynamicOracle>" is created in `dir`. If that file already
// exists, the function returns immediately without extracting anything.
//
// The machine is moved to the CPU for the extraction and back to
// NeuralNetworkImpl::getPreferredDevice() afterwards.
//
// When `dynamicOracle` is true (and the state is neither "tokenizer" nor
// "segmenter"), the applied transition is picked with std::rand() among the
// appliable transitions whose predicted score is within `explorationThreshold`
// of the best appliable score. Otherwise the first gold transition is applied.
//
// Errors are reported through util::myThrow / util::error. They cover: no
// appliable transition, context extraction failure, a malformed regression
// transition, and exceeding safetyNbExamplesMax examples.
void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold, bool memcheck)
{
  torch::AutoGradMode useGrad(false);

  int maxNbExamplesPerFile = 50000;
  std::unordered_map<std::string, Examples> examplesPerState;
  std::mutex examplesMutex;

  std::filesystem::create_directories(dir);

  auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle);

  // The marker file means this (epoch, dynamicOracle) extraction already completed.
  if (std::filesystem::exists(currentEpochAllExtractedFile))
    return;

  fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");

  std::atomic<int> totalNbExamples = 0;

  if (memcheck)
    fmt::print(stderr, "[{}] Memory : {}\n", util::getTime(), util::getMemUsage());

  NeuralNetworkImpl::setDevice(torch::kCPU);
  machine.to(NeuralNetworkImpl::getDevice());
  std::for_each(std::execution::seq, configs.begin(), configs.end(),
    [this, maxNbExamplesPerFile, &examplesPerState, &totalNbExamples, debug, memcheck, dynamicOracle, explorationThreshold, dir, epoch, &examplesMutex](SubConfig & config)
    {
      config.addPredicted(machine.getPredicted());
      config.setStrategy(machine.getStrategyDefinition());
      config.setState(config.getStrategy().getInitialState());

      while (true)
      {
        if (debug)
          config.printForDebug(stderr);

        if (machine.hasSplitWordTransitionSet())
          config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));

        auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
        config.setAppliableTransitions(appliableTransitions);

        torch::Tensor context;

        try
        {
          context = machine.getClassifier(config.getState())->getNN()->extractContext(config);
        } catch(std::exception & e)
        {
          util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
        }

        Transition * transition = nullptr;

        // NOTE(review): the last argument is always true ("true or dynamicOracle").
        auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle);

        // Indexing an empty goldTransitions would be undefined behaviour. Leave
        // goldTransition null instead, so the "No transition appliable" check
        // below reports it.
        Transition * goldTransition = goldTransitions.empty() ? nullptr : goldTransitions[0];
        // NOTE(review): the result of this expression is discarded. goldTransition
        // therefore stays goldTransitions[0] for the parser, and only std::rand()'s
        // state advances. It was probably meant to be `goldTransition = ...`; confirm
        // before changing it, since that alters which gold path the parser follows.
        // The call is kept so the random sequence used below is unchanged.
        if (config.getState() == "parser" and !goldTransitions.empty())
          (void)goldTransitions[std::rand()%goldTransitions.size()];

        int nbClasses = machine.getTransitionSet(config.getState()).size();

        float bestScore = -std::numeric_limits<float>::max();

        float entropy = 0.0;

        // NOTE(review): choiceWithProbability(1.0) presumably always succeeds; confirm.
        if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
        {
          auto & classifier = *machine.getClassifier(config.getState());
          auto prediction = classifier.isRegression() ? classifier.getNN()->forward(context, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(context, config.getState()).squeeze(0), 0);
          entropy  = NeuralNetworkImpl::entropy(prediction);

          std::vector<int> candidates;

          // First pass: best score among the appliable transitions.
          for (unsigned int i = 0; i < prediction.size(0); i++)
          {
            float score = prediction[i].item<float>();
            if (score > bestScore and appliableTransitions[i])
              bestScore = score;
          }

          // Second pass: every appliable transition within explorationThreshold of that best.
          for (unsigned int i = 0; i < prediction.size(0); i++)
          {
            float score = prediction[i].item<float>();
            if (appliableTransitions[i] and bestScore - score <= explorationThreshold)
              candidates.emplace_back(i);
          }

          // candidates is empty when nothing is appliable (or scores are NaN). A modulo
          // by zero would be undefined behaviour, so transition stays null and the
          // check below reports it.
          if (!candidates.empty())
            transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]);

          // If the explored transition is itself a gold one, train towards it.
          for (auto & trans : goldTransitions)
            if (trans == transition)
              goldTransition = trans;
        }
        else
        {
          transition = goldTransition;
        }

        if (!transition or !goldTransition)
        {
          config.printForDebug(stderr);
          util::myThrow("No transition appliable !");
        }

        std::vector<long> goldIndexes;
        bool exampleIsBanned = machine.getClassifier(config.getState())->exampleIsBanned(config);

        if (machine.getClassifier(config.getState())->isRegression())
        {
          // Regression transitions are named "WRITESCORE <object>.<index> <column>".
          // The target is the value of <column> for the referenced word, encoded
          // as a long by util::float2long.
          entropy = 0.0;
          auto errMessage = fmt::format("Invalid regression transition '{}'", transition->getName());
          auto splited = util::split(transition->getName(), ' ');
          if (splited.size() != 3 or splited[0] != "WRITESCORE")
            util::myThrow(errMessage);
          auto col = splited[2];
          splited = util::split(splited[1], '.');
          if (splited.size() != 2)
            util::myThrow(errMessage);
          auto object = Config::str2object(splited[0]);
          int index = std::stoi(splited[1]);

          float regressionTarget = std::stof(config.getConst(col, config.getRelativeWordIndex(object, index), 0));
          goldIndexes.emplace_back(util::float2long(regressionTarget));
        }
        else
        {
          for (auto & t : goldTransitions)
            goldIndexes.emplace_back(machine.getTransitionSet(config.getState()).getTransitionIndex(t));
        }

        if (!exampleIsBanned)
        {
          totalNbExamples += 1;
          if (totalNbExamples >= (int)safetyNbExamplesMax)
            util::error(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));

          // RAII lock: the mutex is released even if one of these calls throws.
          std::lock_guard<std::mutex> lock(examplesMutex);
          auto & stateExamples = examplesPerState[config.getState()];
          stateExamples.addContext(context);
          stateExamples.addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes);
          stateExamples.saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
        }

        config.setChosenActionScore(bestScore);

        transition->apply(config, entropy);
        config.addToHistory(transition->getName());

        auto movement = config.getStrategy().getMovement(config, transition->getName());
        if (debug)
          fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
        if (movement == Strategy::endMovement)
          break;

        config.setState(movement.first);
        config.moveWordIndexRelaxed(movement.second);

        if (config.needsUpdate())
          config.update();

      } // End while true

      if (memcheck)
        fmt::print(stderr, "[{}] Memory : {}\n", util::getTime(), util::getMemUsage());
    }); // End for on configs

  // Flush whatever is left for every state (threshold 0 forces the write).
  for (auto & it : examplesPerState)
    it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);

  NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice());
  machine.to(NeuralNetworkImpl::getDevice());

  std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
  if (!f)
    util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
  std::fclose(f);

  if (memcheck)
    fmt::print(stderr, "[{}] Memory : {}\n", util::getTime(), util::getMemUsage());
  fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(totalNbExamples));
}

// Runs one pass over `loader` and returns the sum of the batch losses divided by
// `nbExamples`.
//
// Each batch is a (data, labels, state) tuple and is handled by the classifier
// of that state. The batch loss is:
//   getLossMultiplier(state) * lossFunction(prediction, labels) * exp(-p) + p
// where p is the network's getLossParameter(state).
//
// With `train` true, gradients are enabled, the machine is in train mode, and
// every batch does zero_grad / backward / optimizer step. With `train` false the
// pass only evaluates.
//
// With `printAdvancement`, a progress line (percentage of `nbExamples`, mean loss
// since the last line, examples per second) is written to stderr every
// printInterval or more examples.
float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples)
{
  constexpr int printInterval = 50;
  int nbExamplesProcessed = 0;
  int totalNbExamplesProcessed = 0;
  float totalLoss = 0.0;
  float lossSoFar = 0.0;

  torch::AutoGradMode useGrad(train);
  machine.trainMode(train);

  auto pastTime = std::chrono::high_resolution_clock::now();

  for (auto & batch : *loader)
  {
    auto data = std::get<0>(batch);
    auto labels = std::get<1>(batch);
    auto state = std::get<2>(batch);

    // Look the classifier up once per batch instead of on every use.
    auto & classifier = *machine.getClassifier(state);

    if (train)
      classifier.getOptimizer().zero_grad();

    auto prediction = classifier.getNN()->forward(data, state);
    if (prediction.dim() == 1)
      prediction = prediction.unsqueeze(0);

    if (classifier.isRegression())
    {
      // Regression targets were stored as longs (util::float2long); convert them back.
      labels = labels.to(torch::kFloat);
      labels /= util::float2longScale;
    }

    auto lossParameter = classifier.getNN()->getLossParameter(state);

    auto loss = classifier.getLossMultiplier(state)*classifier.getLossFunction()(prediction, labels)*(1.0/torch::exp(lossParameter)) + lossParameter;
    float lossAsFloat = 0.0;
    try
    {
      lossAsFloat = loss.item<float>();
    } catch(std::exception & e) {util::myThrow(e.what());}

    totalLoss += lossAsFloat;
    lossSoFar += lossAsFloat;

    if (train)
    {
      loss.backward();
      classifier.getOptimizer().step();
    }

    totalNbExamplesProcessed += labels.size(0);

    if (printAdvancement)
    {
      nbExamplesProcessed += labels.size(0);

      if (nbExamplesProcessed >= printInterval)
      {
        auto actualTime = std::chrono::high_resolution_clock::now();
        double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
        pastTime = actualTime;
        // A zero elapsed time would make the quotient infinite, and converting an
        // infinite double to int is undefined behaviour.
        auto speed = seconds > 0.0 ? static_cast<int>(nbExamplesProcessed/seconds) : 0;
        auto progression = 100.0*totalNbExamplesProcessed / nbExamples;
        auto statusStr = fmt::format(lossSoFar/nbExamplesProcessed < 10.0 ? "{:6.2f}% loss={:<7.3f} speed={:<6}ex/s": "{:6.2f}% loss={:<7.0f} speed={:<6}ex/s", progression, lossSoFar / nbExamplesProcessed, speed);
        if (train)
          fmt::print(stderr, "\r{:80}\rtraining : {}", "", statusStr);
        else
          fmt::print(stderr, "\r{:80}\reval on dev : {}", "", statusStr);
        lossSoFar = 0;
        nbExamplesProcessed = 0;
      }
    }
  }

  return totalLoss / nbExamples;
}

// Runs one training pass (gradients and optimizer steps) over the whole training
// set and returns the loss computed by processDataset for that pass.
float Trainer::epoch(bool printAdvancement)
{
  const int nbTrainExamples = trainDataset->size().value();
  return processDataset(dataLoader, true, printAdvancement, nbTrainExamples);
}

// Evaluates the machine on the whole dev set (train = false: no gradients, no
// optimizer steps) and returns the loss computed by processDataset.
float Trainer::evalOnDev(bool printAdvancement)
{
  const int nbDevExamples = devDataset->size().value();
  return processDataset(devDataLoader, false, printAdvancement, nbDevExamples);
}

// Writes the buffered examples to a tensor file in `dir` once at least `threshold`
// examples were added since the last save (threshold 0 flushes whatever is buffered).
// Nothing is written when the buffer is empty.
//
// The saved tensor is the stacked contexts concatenated (along dim 1) with the
// stacked classes. The file is named:
//   "<state>-<nbClasses>_<firstIndex>-<lastIndex>.<epoch>.<dynamicOracle>.tensor"
// where nbClasses is the size of the first class tensor. After saving, the buffers
// are emptied.
void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int epoch, bool dynamicOracle)
{
  const bool notEnoughNewExamples = currentExampleIndex-lastSavedIndex < (int)threshold;
  if (notEnoughNewExamples or contexts.empty())
    return;

  const int nbClasses = classes[0].size(0);
  const auto firstIndex = lastSavedIndex;
  const auto lastIndex = currentExampleIndex - 1;

  auto stackedContexts = torch::stack(contexts);
  auto stackedClasses = torch::stack(classes);
  auto merged = torch::cat({stackedContexts, stackedClasses}, 1);

  const auto outputName = fmt::format("{}-{}_{}-{}.{}.{}.tensor", state, nbClasses, firstIndex, lastIndex, epoch, dynamicOracle);
  torch::save(merged, dir / outputName);

  lastSavedIndex = currentExampleIndex;
  contexts.clear();
  classes.clear();
}

// Buffers one example context and advances the running example index.
void Trainer::Examples::addContext(torch::Tensor & context)
{
  contexts.push_back(context);
  ++currentExampleIndex;
}

// Builds the gold target for `goldIndexes` with the loss function's encoding. It is
// assigned to every buffered context that does not have a class yet, so that
// `classes` catches up with `contexts` (index i of one matches index i of the other).
void Trainer::Examples::addClass(const LossFunction & lossFct, int nbClasses, const std::vector<long> & goldIndexes)
{
  const auto goldTarget = lossFct.getGoldFromClassesIndexes(nbClasses, goldIndexes);

  if (classes.size() < contexts.size())
    classes.resize(contexts.size(), goldTarget);
}

// Converts the textual name of a TrainAction (e.g. "ExtractGold", "Save") to its
// enum value. Unknown names are reported through util::myThrow.
Trainer::TrainAction Trainer::str2TrainAction(const std::string & s)
{
  static const std::unordered_map<std::string, TrainAction> actionsByName
  {
    {"ExtractGold", TrainAction::ExtractGold},
    {"ExtractDynamic", TrainAction::ExtractDynamic},
    {"DeleteExamples", TrainAction::DeleteExamples},
    {"ResetOptimizer", TrainAction::ResetOptimizer},
    {"ResetParameters", TrainAction::ResetParameters},
    {"Save", TrainAction::Save},
  };

  const auto match = actionsByName.find(s);
  if (match != actionsByName.end())
    return match->second;

  util::myThrow(fmt::format("unknown TrainAction '{}'", s));

  // Only reached if util::myThrow returns.
  return TrainAction::ExtractGold;
}

void Trainer::extractActionSequence(BaseConfig & config)
{
  config.addPredicted(machine.getPredicted());
  config.setStrategy(machine.getStrategyDefinition());
  config.setState(config.getStrategy().getInitialState());

  int curSeq = 0;
  int curSeqStartIndex = -1;
  int curInputIndex = 0;
  int curInputSeqSize = 0;
  int curOutputSeqSize = 0;
  int maxInputSeqSize = 0;
  int maxOutputSeqSize = 0;
  bool newSent = true;
  std::vector<std::string> transitionsIndexes;

  while (true)
  {
    if (config.hasCharacter(0))
      curInputIndex = config.getCharacterIndex();
    else
      curInputIndex = config.getWordIndex();

    if (curSeqStartIndex == -1 or newSent)
    {
      newSent = false;
      curSeqStartIndex = curInputIndex;
    }

    if (machine.hasSplitWordTransitionSet())
      config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));

    auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
    config.setAppliableTransitions(appliableTransitions);

    auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true);

    Transition * transition = goldTransitions[0];

    if (machine.getClassifier(config.getState())->isRegression())
      util::myThrow("Regressions are not supported in extract action sequence mode");

    transitionsIndexes.push_back(fmt::format("{}", machine.getTransitionSet(config.getState()).getTransitionIndex(transition)));
    maxOutputSeqSize = std::max(maxOutputSeqSize, curOutputSeqSize++);
    curInputSeqSize = -curSeqStartIndex + curInputIndex;
    maxInputSeqSize = std::max(maxInputSeqSize, curInputSeqSize++);
    if (util::split(transition->getName(), ' ')[0] == "EOS")
      if (++curSeq % 3 == 0)
      {
        newSent = true;
        std::string curSeq = "";
        for (int i = curSeqStartIndex; i <= curInputIndex; i++)
          curSeq += config.hasCharacter(0) ? fmt::format("{}", config.getLetter(i)) : fmt::format("{} ", std::string(config.getAsFeature("FORM", i)));
        fmt::print(stdout, "{}\n{}\n\n", config.hasCharacter(0) ? curSeq : util::strip(curSeq), util::join(" ", transitionsIndexes));
        curOutputSeqSize = 0;
        curInputSeqSize = 0;
        transitionsIndexes.clear();
      }

    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = config.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    config.moveWordIndexRelaxed(movement.second);
  }

  if (curSeqStartIndex != curInputIndex)
  {
    std::string curSeq = "";
    for (int i = curSeqStartIndex; i <= curInputIndex; i++)
      curSeq += config.hasCharacter(0) ? fmt::format("{}", config.getLetter(i)) : fmt::format("{} ", std::string(config.getAsFeature("FORM", i)));
    fmt::print(stdout, "{}\n{}\n\n", config.hasCharacter(0) ? curSeq : util::strip(curSeq), util::join(" ", transitionsIndexes));
    curOutputSeqSize = 0;
    curInputSeqSize = 0;
    curSeqStartIndex = curInputIndex;
  }

  fmt::print(stderr, "Longest output sequence : {}\n", maxOutputSeqSize);
  fmt::print(stderr, "Longest input sequence : {}\n", maxInputSeqSize);
}