#include "MacaonTrain.hpp"
#include <filesystem>
#include "util.hpp"
#include "NeuralNetwork.hpp"

namespace po = boost::program_options;

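// Builds the description of all command-line arguments, required and optional.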
po::options_description MacaonTrain::getOptionsDescription()
{
  po::options_description desc("Command-Line Arguments");

  po::options_description req("Required");
  req.add_options()
    ("model", po::value<std::string>()->required(),
      "Directory containing the machine file to train")
    ("trainTSV", po::value<std::string>()->required(),
      "TSV file of the training corpus, in CONLLU format");

  po::options_description opt("Optional");
  opt.add_options()
    ("debug,d", "Print debuging infos on stderr")
    ("silent", "Don't print speed and progress")
    ("devScore", "Compute score on dev instead of loss (slower)")
    ("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"),
      "Comma separated column names that describes the input/output format")
    ("trainTXT", po::value<std::string>()->default_value(""),
      "Raw text file of the training corpus")
    ("devTSV", po::value<std::string>()->default_value(""),
      "TSV file of the development corpus, in CONLLU format")
    ("devTXT", po::value<std::string>()->default_value(""),
      "Raw text file of the development corpus")
    ("nbEpochs,n", po::value<int>()->default_value(5),
      "Number of training epochs")
    ("batchSize", po::value<int>()->default_value(64),
      "Number of examples per batch")
    ("explorationThreshold", po::value<float>()->default_value(0.1),
      "Maximum probability difference with the best scoring transition, for a transition to be explored during dynamic extraction of dataset")
    ("machine", po::value<std::string>()->default_value(""),
      "Reading machine file content")
    ("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold,ResetParameters"),
      "Description of what should happen during training")
    ("loss", po::value<std::string>()->default_value("CrossEntropy"),
      "Loss function to use during training : CrossEntropy | bce | mse | hinge")
    ("help,h", "Produce this help message");

  desc.add(req).add(opt);

  return desc;
}

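// Parses the command line against the given description; prints the help
// message and exits if --help was given.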
po::variables_map MacaonTrain::checkOptions(po::options_description & od)
{
  po::variables_map vm;

  try {po::store(po::parse_command_line(argc, argv, od), vm);}
  catch(std::exception & e) {util::myThrow(e.what());}

  if (vm.count("help"))
  {
    std::stringstream ss;
    ss << od;
    fmt::print(stderr, "{}\n", ss.str());
    exit(0);
  }

  try {po::notify(vm);}
  catch(std::exception& e) {util::myThrow(e.what());}

  return vm;
}

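// Parses a training strategy string: groups separated by ':', each of the form
// 'epoch,Action[,Action...]', e.g. the default "0,ExtractGold,ResetParameters".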
Trainer::TrainStrategy MacaonTrain::parseTrainStrategy(std::string s)
{
  Trainer::TrainStrategy ts;

  try
  {
    auto groups = util::split(s, ':');
    for (auto & group : groups)
    {
      auto elements = util::split(group, ',');

      int epoch = std::stoi(elements[0]);

      for (unsigned int i = 1; i < elements.size(); i++)
        ts[epoch].insert(Trainer::str2TrainAction(elements[i]));
    }
  } catch (std::exception & e) {util::myThrow(fmt::format("caught '{}' parsing '{}'", e.what(), s));}

  return ts;
}

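// Multiplicative learning-rate decay: lr <- lr * (1 - rate), applied to every
// optimizer parameter group that received gradients.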
template 
<
  typename Optimizer = torch::optim::Adam,
  typename OptimizerOptions = torch::optim::AdamOptions
>
inline auto decay(Optimizer &optimizer, double rate) -> void
{
  for (auto &group : optimizer.param_groups())
  {
    for (auto &param : group.params())
    {
      if (!param.grad().defined())
        continue;

      auto &options = static_cast<OptimizerOptions &>(group.options());
      options.lr(options.lr() * (1.0 - rate));
    }
  }
}

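// Trains the reading machine: extracts examples, runs the training epochs
// described by the train strategy, and saves the best model.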
int MacaonTrain::main()
{
  auto od = getOptionsDescription();
  auto variables = checkOptions(od);

  std::filesystem::path modelPath(variables["model"].as<std::string>());
  auto machinePath = modelPath / "machine.rm";
  auto mcd = variables["mcd"].as<std::string>();
  auto trainTsvFile = variables["trainTSV"].as<std::string>();
  auto trainRawFile = variables["trainTXT"].as<std::string>();
  auto devTsvFile = variables["devTSV"].as<std::string>();
  auto devRawFile = variables["devTXT"].as<std::string>();
  auto nbEpoch = variables["nbEpochs"].as<int>();
  auto batchSize = variables["batchSize"].as<int>();
  bool debug = variables.count("debug") != 0;
  bool printAdvancement = !debug && variables.count("silent") == 0;
  bool computeDevScore = variables.count("devScore") != 0;
  auto machineContent = variables["machine"].as<std::string>();
  auto trainStrategyStr = variables["trainStrategy"].as<std::string>();
  auto lossFunction = variables["loss"].as<std::string>();
  auto explorationThreshold = variables["explorationThreshold"].as<float>();

  auto trainStrategy = parseTrainStrategy(trainStrategyStr);

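  // Let cuDNN benchmark convolution algorithms and cache the fastest ones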
  torch::globalContext().setBenchmarkCuDNN(true);

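  // If a machine description was passed on the command line, write it to the model directory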
  if (!machineContent.empty())
  {
    std::FILE * file = std::fopen(machinePath.c_str(), "w");
    if (!file)
      util::error(fmt::format("can't open file '{}'\n", machinePath.c_str()));
    fmt::print(file, "{}", machineContent);
    std::fclose(file);
  }

  fmt::print(stderr, "[{}] Training using device : {}\n", util::getTime(), NeuralNetworkImpl::device.str());

  try
  {

  ReadingMachine machine(machinePath.string());

  BaseConfig goldConfig(mcd, trainTsvFile, trainRawFile);
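  // Dev gold config: when scoring on dev, use the gold TSV only if no raw text is given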
  BaseConfig devGoldConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);

  Trainer trainer(machine, batchSize, lossFunction);
  Decoder decoder(machine);

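  // If dicts already exist on disk, resume from the last saved model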
  if (!util::findFilesByExtension(machinePath.parent_path(), ".dict").empty())
  {
    machine.loadDicts();
    machine.getClassifier()->getNN()->registerEmbeddings();
    machine.loadLastSaved();
    machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
  }

  float bestDevScore = computeDevScore ? std::numeric_limits<float>::lowest() : std::numeric_limits<float>::max();

  auto trainInfos = machinePath.parent_path() / "train.info";

  int currentEpoch = 0;

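  // Resume the epoch counter and best dev score from a previous run's train.info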
  if (std::filesystem::exists(trainInfos))
  {
    std::FILE * f = std::fopen(trainInfos.c_str(), "r");
    char buffer[1024];
    while (std::fgets(buffer, sizeof buffer, f) == buffer)
    {
      bool saved = util::split(util::split(buffer, '\t')[0], ' ').back() == "SAVED";
      float devScoreMean = std::stof(util::split(buffer, '\t').back());
      if (saved)
        bestDevScore = devScoreMean;
      currentEpoch++;
    }
    std::fclose(f);
  }

  // Restore the optimizer state from a previous run, if a checkpoint exists
  auto optimizerCheckpoint = machinePath.parent_path() / "checkpoint.optimizer";
  if (std::filesystem::exists(optimizerCheckpoint))
  {
    machine.getClassifier()->resetOptimizer();
    machine.getClassifier()->loadOptimizer(optimizerCheckpoint);
  }

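  // Main training loop; trainStrategy decides what happens at each epoch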
  for (; currentEpoch < nbEpoch; currentEpoch++)
  {
    bool saved = false;

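    // DeleteExamples: remove previously extracted train (and dev) examples from disk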
    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::DeleteExamples))
    {
      for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/train"))
        if (entry.is_regular_file())
          std::filesystem::remove(entry.path());

      if (!computeDevScore)
        for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/dev"))
          if (entry.is_regular_file())
            std::filesystem::remove(entry.path());
    }
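    // ExtractGold / ExtractDynamic: (re)generate the example files, statically
    // from the gold configuration or dynamically with exploration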
    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic))
    {
      machine.setDictsState(trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic) ? Dict::State::Closed : Dict::State::Open);
      trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold);
      if (!computeDevScore)
      {
        machine.setDictsState(Dict::State::Closed);
        trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold);
      }
    }
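    // ResetParameters / ResetOptimizer: reinitialize the network and/or its optimizer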
    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetOptimizer))
    {
      if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters))
      {
        machine.resetClassifier();
        machine.trainMode(currentEpoch == 0);
        machine.getClassifier()->getNN()->registerEmbeddings();
        machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
        fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters()));
      }

      machine.getClassifier()->resetOptimizer();
    }
    if (trainStrategy[currentEpoch].count(Trainer::TrainAction::Save))
    {
      saved = true;
    }

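    // Load the extracted examples and train for one epoch, then evaluate on dev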
    trainer.makeDataLoader(modelPath/"examples/train");
    if (!computeDevScore)
      trainer.makeDevDataLoader(modelPath/"examples/dev");

    float loss = trainer.epoch(printAdvancement);
    if (debug)
      fmt::print(stderr, "Decoding dev :\n");
    std::vector<std::pair<float,std::string>> devScores;
    if (computeDevScore)
    {
      BaseConfig devConfig(mcd, devRawFile.empty() ? devTsvFile : "", devRawFile);
      decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
      decoder.evaluate(devConfig, modelPath, devTsvFile);
      devScores = decoder.getF1Scores(machine.getPredicted());
    }
    else
    {
      float devLoss = trainer.evalOnDev(printAdvancement);
      devScores.emplace_back(devLoss, "Loss");
    }

    std::string devScoresStr;
    float devScoreMean = 0;
    for (auto & score : devScores)
    {
      if (computeDevScore)
        devScoresStr += fmt::format("{}({:5.2f}%),", score.second, score.first);
      else
        devScoresStr += fmt::format("{}({:6.4f}),", score.second, 100.0*score.first);
      devScoreMean += score.first;
    }
    if (!devScoresStr.empty())
      devScoresStr.pop_back();
    devScoreMean /= devScores.size();

    if (computeDevScore)
      saved = saved or devScoreMean >= bestDevScore;
    else
      saved = saved or devScoreMean <= bestDevScore;

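    // A better dev score (or an explicit Save action) triggers saving the best model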
    if (saved)
    {
      bestDevScore = devScoreMean;
      machine.saveBest();
    }
    machine.saveLast();
    machine.getClassifier()->saveOptimizer(optimizerCheckpoint);
    if (printAdvancement)
      fmt::print(stderr, "\r{:80}\r", "");
    std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.4f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), 100.0*loss, devScoresStr, saved ? "SAVED" : "");
    fmt::print(stderr, "{}\n", iterStr);
    std::FILE * f = std::fopen(trainInfos.c_str(), "a");
    if (!f)
      util::error(fmt::format("can't open file '{}'\n", trainInfos.c_str()));
    fmt::print(f, "{}\t{}\n", iterStr, devScoreMean);
    std::fclose(f);
  }

  }
  catch(std::exception & e) {util::error(e);}

  return 0;
}

MacaonTrain::MacaonTrain(int argc, char ** argv) : argc(argc), argv(argv)
{
}