#include "Trainer.hpp"
#include "SubConfig.hpp"

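// Map a loss name (case-insensitive) to the corresponding torch loss functor,
// all configured with mean reduction over the batch.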
LossFunction::LossFunction(std::string name)
{
  if (util::lower(name) == "crossentropy")
    fct = torch::nn::CrossEntropyLoss(torch::nn::CrossEntropyLossOptions().reduction(torch::kMean));
  else if (util::lower(name) == "bce")
    fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean));
  else if (util::lower(name) == "mse")
    fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean));
  else if (util::lower(name) == "hinge")
    fct = CustomHingeLoss();
  else
    util::myThrow(fmt::format("unknown loss function name '{}'", name));
}

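// Compute the loss for a batch of predictions against the gold tensor produced
// by getGoldFromClassesIndexes() below.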
torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor gold)
{
  auto index = fct.index();

  // Dispatch on the variant's active alternative:
  // 0 = CrossEntropyLoss, 1 = BCELoss, 2 = MSELoss, 3 = CustomHingeLoss.
  if (index == 0) // cross entropy takes raw logits and a 1-D tensor of class indexes
    return std::get<0>(fct)(prediction, gold.reshape(gold.dim() == 0 ? 1 : gold.size(0)));
  if (index == 1) // BCE expects probabilities, hence the softmax over the logits
    return std::get<1>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat));
  if (index == 2) // MSE compares the softmax distribution to the multi-hot gold
    return std::get<2>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat));
  if (index == 3) // custom hinge loss, also on the softmax distribution
    return std::get<3>(fct)(torch::softmax(prediction, 1), gold);

  util::myThrow("loss is not defined");
  return torch::Tensor();
}

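// Build the gold tensor for one example: a single class index for cross entropy,
// or a multi-hot vector of size nbClasses for the other losses.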
torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::vector<int> & goldIndexes) const
{
  auto index = fct.index();

  if (index == 0)
  {
    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
    gold[0] = goldIndexes.at(0);
    return gold;
  }
  if (index == 1 or index == 2 or index == 3)
  {
    auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong));
    for (auto goldIndex : goldIndexes)
      gold[goldIndex] = 1;
    return gold;
  }

  util::myThrow("loss is not defined");
  return torch::Tensor();
}

Trainer::Trainer(ReadingMachine & machine, int batchSize, std::string lossFunctionName) : machine(machine), batchSize(batchSize), lossFct(lossFunctionName)
{
}

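// Build the data loader over the extracted training examples in dir;
// workers(0) keeps loading in the main thread.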
void Trainer::makeDataLoader(std::filesystem::path dir)
{
  trainDataset.reset(new Dataset(dir));
  dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}

void Trainer::makeDevDataLoader(std::filesystem::path dir)
{
  devDataset.reset(new Dataset(dir));
  devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}

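// Extract a training dataset from the gold configuration, then save the
// machine's dictionaries, which may have grown during extraction.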
void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold)
{
  SubConfig config(goldConfig, goldConfig.getNbLines());

  machine.trainMode(false);

  extractExamples(config, debug, dir, epoch, dynamicOracle, explorationThreshold);

  machine.saveDicts();
}

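// Run the machine over the whole configuration, recording one (context, gold)
// training example per transition taken. With the dynamic oracle enabled the
// machine follows its own predictions, so the extracted configurations better
// match those encountered at decoding time. Examples are flushed to disk in chunks.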
void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold)
{
  torch::AutoGradMode useGrad(false);

  int maxNbExamplesPerFile = 50000;
  std::map<std::string, Examples> examplesPerState;

  std::filesystem::create_directories(dir);

  config.addPredicted(machine.getPredicted());
  config.setStrategy(machine.getStrategyDefinition());
  config.setState(config.getStrategy().getInitialState());
  machine.getClassifier()->setState(config.getState());

  auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle);

  if (std::filesystem::exists(currentEpochAllExtractedFile))
    return;

  fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");

  int totalNbExamples = 0;

  while (true)
  {
    if (debug)
      config.printForDebug(stderr);

    if (machine.hasSplitWordTransitionSet())
      config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
    auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
    config.setAppliableTransitions(appliableTransitions);

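    // Extract the network's input features for the current configuration.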
    std::vector<std::vector<long>> context;

    try
    {
      context = machine.getClassifier()->getNN()->extractContext(config);
    } catch(std::exception & e)
    {
      util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
    }

    Transition * transition = nullptr;

    // Retrieve every best gold transition; one of them, picked at random, serves
    // as the gold label (and as the action to take when not exploring).
    auto goldTransitions = machine.getTransitionSet().getBestAppliableTransitions(config, appliableTransitions, true);
    Transition * goldTransition = goldTransitions[std::rand()%goldTransitions.size()];
    int nbClasses = machine.getTransitionSet().size();
      
    // Dynamic oracle: follow the classifier's own prediction instead of the gold
    // transition (choiceWithProbability(1.0) presumably always passes; it is kept
    // as a hook for stochastic exploration), except in the tokenizer and segmenter states.
    if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
    {
      // Score the current configuration with the classifier (only the first
      // extracted context is used) and normalize the logits with softmax.
      auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
      auto prediction = torch::softmax(machine.getClassifier()->getNN()(neuralInput), -1).squeeze();
  
      // lowest() rather than min(): min() is the smallest positive float, so a
      // softmax score that underflows to 0 could never beat it.
      float bestScore = std::numeric_limits<float>::lowest();
      std::vector<int> candidates;

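      // First pass: find the best score among the applicable transitions.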
      for (unsigned int i = 0; i < prediction.size(0); i++)
      {
        float score = prediction[i].item<float>();
        if (score > bestScore and appliableTransitions[i])
          bestScore = score;
      }

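      // Second pass: keep every applicable transition whose score is within
      // explorationThreshold of the best, then sample uniformly among them.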
      for (unsigned int i = 0; i < prediction.size(0); i++)
      {
        float score = prediction[i].item<float>();
        if (appliableTransitions[i] and bestScore - score <= explorationThreshold)
          candidates.emplace_back(i);
      }

      transition = machine.getTransitionSet().getTransition(candidates[std::rand()%candidates.size()]);
    }
    else
    {
      transition = goldTransition;
    }

    if (!transition or !goldTransition)
    {
      config.printForDebug(stderr);
      util::myThrow("No transition appliable !");
    }

    totalNbExamples += context.size();
    if (totalNbExamples >= (int)safetyNbExamplesMax)
      util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));

    std::vector<int> goldIndexes;
    for (auto & t : goldTransitions)
      goldIndexes.emplace_back(machine.getTransitionSet().getTransitionIndex(t));

    examplesPerState[config.getState()].addContext(context);
    examplesPerState[config.getState()].addClass(lossFct, nbClasses, goldIndexes);
    examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);

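    // Apply the chosen transition, then follow the strategy to get the next
    // state and the word-index movement.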
    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = config.getStrategy().getMovement(config, transition->getName());
    if (debug)
      fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    machine.getClassifier()->setState(movement.first);
    config.moveWordIndexRelaxed(movement.second);

    if (config.needsUpdate())
      config.update();
  }

  for (auto & it : examplesPerState)
    it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);

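  // Create an empty marker file so this (epoch, dynamicOracle) extraction is
  // not redone on a subsequent run.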
  std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
  if (!f)
    util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
  std::fclose(f);

  fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(totalNbExamples));
}

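// One pass over a dataset: per batch, forward the examples through the network,
// compute the loss and, when train is true, backpropagate and step the optimizer.
// Returns the sum of per-batch losses divided by the total number of examples.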
float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples)
{
  constexpr int printInterval = 50;
  int nbExamplesProcessed = 0;
  int totalNbExamplesProcessed = 0;
  float totalLoss = 0.0;
  float lossSoFar = 0.0;

  torch::AutoGradMode useGrad(train);
  machine.trainMode(train);

  auto pastTime = std::chrono::high_resolution_clock::now();

  for (auto & batch : *loader)
  {
    if (train)
      machine.getClassifier()->getOptimizer().zero_grad();

    auto data = std::get<0>(batch);
    auto labels = std::get<1>(batch);
    auto state = std::get<2>(batch);

    machine.getClassifier()->setState(state);

    auto prediction = machine.getClassifier()->getNN()(data);
    if (prediction.dim() == 1)
      prediction = prediction.unsqueeze(0);

    auto loss = machine.getClassifier()->getLossMultiplier()*lossFct(prediction, labels);
    float lossAsFloat = 0.0;
    try
    {
      lossAsFloat = loss.item<float>();
    } catch(std::exception & e) {util::myThrow(e.what());}

    totalLoss += lossAsFloat;
    lossSoFar += lossAsFloat;

    if (train)
    {
      loss.backward();
      machine.getClassifier()->getOptimizer().step();
    }

    totalNbExamplesProcessed += labels.size(0);

    if (printAdvancement)
    {
      nbExamplesProcessed += labels.size(0);

      if (nbExamplesProcessed >= printInterval)
      {
        auto actualTime = std::chrono::high_resolution_clock::now();
        double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
        pastTime = actualTime;
        auto speed = (int)(nbExamplesProcessed/seconds);
        auto progression = 100.0*totalNbExamplesProcessed / nbExamples;
        auto statusStr = fmt::format("{:6.2f}% loss={:<7.3f} speed={:<6}ex/s", progression, lossSoFar, speed);
        if (train)
          fmt::print(stderr, "\r{:80}\rtraining : {}", "", statusStr);
        else
          fmt::print(stderr, "\r{:80}\reval on dev : {}", "", statusStr);
        lossSoFar = 0;
        nbExamplesProcessed = 0;
      }
    }
  }

  return totalLoss / nbExamples;
}

float Trainer::epoch(bool printAdvancement)
{
  return processDataset(dataLoader, true, printAdvancement, trainDataset->size().value());
}

float Trainer::evalOnDev(bool printAdvancement)
{
  return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
}

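// Flush the buffered examples to a tensor file once at least `threshold` new
// examples have accumulated (a threshold of 0 forces a final flush). Contexts
// and gold classes are concatenated side by side into one 2-D tensor per file.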
void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int epoch, bool dynamicOracle)
{
  if (currentExampleIndex-lastSavedIndex < (int)threshold)
    return;
  if (contexts.empty())
    return;

  // Width of the gold part of each example (1 in the cross-entropy case, where
  // the gold is a single class index); only used to build the file name.
  int nbClasses = classes[0].size(0);

  auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
  auto filename = fmt::format("{}-{}_{}-{}.{}.{}.tensor", state, nbClasses, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
  torch::save(tensorToSave, dir/filename);
  lastSavedIndex = currentExampleIndex;
  contexts.clear();
  classes.clear();
}

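// Buffer one tensor per extracted context. from_blob() does not take ownership
// of the vector's memory, so clone() is required before the vector is destroyed.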
void Trainer::Examples::addContext(std::vector<std::vector<long>> & context)
{
  for (auto & element : context)
    contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone());

  currentExampleIndex += context.size();
}

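// Buffer the gold tensor, replicated so that classes stays aligned one-to-one
// with contexts.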
void Trainer::Examples::addClass(const LossFunction & lossFct, int nbClasses, const std::vector<int> & goldIndexes)
{
  auto gold = lossFct.getGoldFromClassesIndexes(nbClasses, goldIndexes);

  while (classes.size() < contexts.size())
    classes.emplace_back(gold);
}

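// Parse a TrainAction from its textual name.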
Trainer::TrainAction Trainer::str2TrainAction(const std::string & s)
{
  if (s == "ExtractGold")
    return TrainAction::ExtractGold;
  if (s == "ExtractDynamic")
    return TrainAction::ExtractDynamic;
  if (s == "DeleteExamples")
    return TrainAction::DeleteExamples;
  if (s == "ResetOptimizer")
    return TrainAction::ResetOptimizer;
  if (s == "ResetParameters")
    return TrainAction::ResetParameters;
  if (s == "Save")
    return TrainAction::Save;

  util::myThrow(fmt::format("unknown TrainAction '{}'", s));

  return TrainAction::ExtractGold;
}