#include "MacaonTrain.hpp"
#include <filesystem>
#include "util.hpp"
#include "NeuralNetwork.hpp"

namespace po = boost::program_options;

/// Builds the command-line schema for the training tool.
///
/// Two groups are declared: "Required" (model directory, MCD file, training
/// TSV) and "Optional" (flags, dev corpus files, epoch count, help). Both are
/// merged into a single top-level description that is returned by value.
po::options_description MacaonTrain::getOptionsDescription()
{
  // Options that must be present on every invocation (enforced by po::notify).
  po::options_description mandatory("Required");
  auto addMandatory = mandatory.add_options();
  addMandatory("model", po::value<std::string>()->required(),
    "Directory containing the machine file to train");
  addMandatory("mcd", po::value<std::string>()->required(),
    "Multi Column Description file that describes the input format");
  addMandatory("trainTSV", po::value<std::string>()->required(),
    "TSV file of the training corpus, in CONLLU format");

  // Flags and options that carry a default value.
  po::options_description extras("Optional");
  auto addExtra = extras.add_options();
  addExtra("debug,d", "Print debuging infos on stderr");
  addExtra("silent", "Don't print speed and progress");
  addExtra("devScore", "Compute score on dev instead of loss (slower)");
  addExtra("trainTXT", po::value<std::string>()->default_value(""),
    "Raw text file of the training corpus");
  addExtra("devTSV", po::value<std::string>()->default_value(""),
    "TSV file of the development corpus, in CONLLU format");
  addExtra("devTXT", po::value<std::string>()->default_value(""),
    "Raw text file of the development corpus");
  addExtra("nbEpochs,n", po::value<int>()->default_value(5),
    "Number of training epochs");
  addExtra("help,h", "Produce this help message");

  po::options_description all("Command-Line Arguments ");
  all.add(mandatory);
  all.add(extras);
  return all;
}

/// Parses argc/argv against `od` and returns the resulting variables map.
///
/// If "--help" is present, the description is printed to stderr and the
/// process exits with status 0 before required-option validation runs.
/// Parsing or validation failures are rethrown through util::myThrow with
/// the underlying exception's message.
po::variables_map MacaonTrain::checkOptions(po::options_description & od)
{
  po::variables_map parsed;

  try
  {
    auto commandLine = po::parse_command_line(argc, argv, od);
    po::store(commandLine, parsed);
  }
  catch (std::exception & ex)
  {
    util::myThrow(ex.what());
  }

  // Help is handled before notify() so that missing required options do not
  // prevent the help text from being shown.
  const bool helpRequested = parsed.count("help") != 0;
  if (helpRequested)
  {
    std::stringstream helpText;
    helpText << od;
    fmt::print(stderr, "{}\n", helpText.str());
    exit(0);
  }

  try
  {
    po::notify(parsed);
  }
  catch (std::exception & ex)
  {
    util::myThrow(ex.what());
  }

  return parsed;
}

/// Pre-populates every dictionary of `rm` with the values found in `config`.
///
/// For each of the columns "FORM" and "LEMMA" that `config` has (checked at
/// position (0,0)), every dictionary of the machine receives the value of
/// that column for every line and every hypothesis index below
/// Config::nbHypothesesMax. Each pass is bracketed by countOcc(true) /
/// countOcc(false); presumably this toggles occurrence counting during the
/// insertions — confirm against the Dict implementation.
void MacaonTrain::fillDicts(ReadingMachine & rm, const Config & config)
{
  static std::vector<std::string> interestingColumns{"FORM", "LEMMA"};

  for (auto & column : interestingColumns)
  {
    if (!config.has(column, 0, 0))
      continue;

    for (auto & entry : rm.getDicts())
    {
      auto & dict = entry.second;
      dict.countOcc(true);
      for (unsigned int line = 0; line < config.getNbLines(); ++line)
        for (unsigned int hypothesis = 0; hypothesis < Config::nbHypothesesMax; ++hypothesis)
          dict.getIndexOrInsert(config.getConst(column, line, hypothesis));
      dict.countOcc(false);
    }
  }
}

/// Entry point of the training tool.
///
/// Parses the command line, loads the reading machine from
/// `<model>/machine.rm`, fills its dictionaries from the training corpus and
/// trains for `nbEpochs` epochs. After each epoch the model is evaluated on
/// the dev corpus, either by decoding it and computing F1 scores
/// (`--devScore`, higher is better) or by computing the dev loss (lower is
/// better). The machine is saved whenever the mean dev value improves on the
/// best seen so far. Exceptions raised during training are reported through
/// util::error.
///
/// @return 0
int MacaonTrain::main()
{
  auto od = getOptionsDescription();
  auto variables = checkOptions(od);

  std::filesystem::path modelPath(variables["model"].as<std::string>());
  auto machinePath = modelPath / "machine.rm";
  auto mcdFile = variables["mcd"].as<std::string>();
  auto trainTsvFile = variables["trainTSV"].as<std::string>();
  auto trainRawFile = variables["trainTXT"].as<std::string>();
  auto devTsvFile = variables["devTSV"].as<std::string>();
  auto devRawFile = variables["devTXT"].as<std::string>();
  auto nbEpoch = variables["nbEpochs"].as<int>();
  bool debug = variables.count("debug") != 0;
  bool printAdvancement = !debug && variables.count("silent") == 0;
  bool computeDevScore = variables.count("devScore") != 0;

  fmt::print(stderr, "Training using device : {}\n", NeuralNetworkImpl::device.str());

  try
  {
    ReadingMachine machine(machinePath.string());

    BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
    // When scoring by decoding and a raw dev text is available, the dev config
    // is built from the raw text only; the TSV is still used as the reference
    // in decoder.evaluate() below.
    BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
    SubConfig config(goldConfig);

    fillDicts(machine, goldConfig);

    Trainer trainer(machine);
    trainer.createDataset(config, debug);
    if (!computeDevScore)
    {
      SubConfig devConfig(devGoldConfig);
      trainer.createDevDataset(devConfig, debug);
    }

    Decoder decoder(machine);

    // F1 scores are maximised, losses minimised. Use lowest() rather than
    // min(): numeric_limits<float>::min() is the smallest *positive* float,
    // so a first-epoch score of 0 would never be considered an improvement.
    float bestDevScore = computeDevScore ? std::numeric_limits<float>::lowest() : std::numeric_limits<float>::max();

    for (int i = 0; i < nbEpoch; i++)
    {
      float loss = trainer.epoch(printAdvancement);
      machine.getStrategy().reset();
      if (debug)
        fmt::print(stderr, "Decoding dev :\n");

      std::vector<std::pair<float,std::string>> devScores;
      if (computeDevScore)
      {
        auto devConfig = devGoldConfig;
        decoder.decode(devConfig, 1, debug, printAdvancement);
        machine.getStrategy().reset();
        decoder.evaluate(devConfig, modelPath, devTsvFile);
        devScores = decoder.getF1Scores(machine.getPredicted());
      }
      else
      {
        float devLoss = trainer.evalOnDev(printAdvancement);
        devScores.emplace_back(devLoss, "Loss");
      }

      std::string devScoresStr;
      float devScoreSum = 0;
      for (auto & score : devScores)
      {
        if (computeDevScore)
          devScoresStr += fmt::format("{}({:5.2f}%),", score.second, score.first);
        else
          devScoresStr += fmt::format("{}({:6.1f}),", score.second, score.first);
        devScoreSum += score.first;
      }
      if (!devScoresStr.empty())
        devScoresStr.pop_back(); // drop the trailing ','

      // An empty score list would otherwise divide by zero (NaN); with no
      // dev signal there is nothing to compare, so the model is not saved.
      bool saved = false;
      if (!devScores.empty())
      {
        float devScoreMean = devScoreSum / static_cast<float>(devScores.size());
        saved = computeDevScore ? devScoreMean > bestDevScore : devScoreMean < bestDevScore;
        if (saved)
        {
          bestDevScore = devScoreMean;
          machine.save();
        }
      }

      if (debug)
        fmt::print(stderr, "Epoch {:^5} loss = {:6.1f} dev = {} {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
      else
        fmt::print(stderr, "\r{:80}\rEpoch {:^5} loss = {:6.1f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
    }
  }
  catch(std::exception & e) {util::error(e);}

  return 0;
}

MacaonTrain::MacaonTrain(int argc, char ** argv) : argc(argc), argv(argv)
{
}