/// @file programOptionsTemplates.cpp
/// @author Franck Dary
/// @version 1.0
/// @date 2019-05-27

#include <cstdio>
#include <cstdlib>
#include <boost/program_options.hpp>
#include "programOptionsTemplates.hpp"
#include "BD.hpp"
#include "Config.hpp"
#include "TransitionMachine.hpp"
#include "Trainer.hpp"
#include "ProgramParameters.hpp"

namespace po = boost::program_options;

/// @brief Get the description of the mandatory and optional macaon_train program arguments.
///
/// @return The combined options description.
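///
/// Typical use together with checkOptions below (illustrative sketch;
/// it mirrors what loadTrainProgramParameters does in this file) :
/// @code
/// auto od = getTrainOptionsDescription();
/// po::variables_map vm = checkOptions(od, argc, argv);
/// std::string expName = vm["expName"].as<std::string>();
/// @endcode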
po::options_description getTrainOptionsDescription()
{
  po::options_description desc("Command-Line Arguments");

  po::options_description req("Required");
  req.add_options()
    ("expName", po::value<std::string>()->required(),
      "Name of this experiment")
    ("templateName", po::value<std::string>()->required(),
      "Name of the template folder")
    ("tm", po::value<std::string>()->required(),
      "File describing the Tape Machine we will train")
    ("bd", po::value<std::string>()->required(),
      "BD file that describes the multi-tapes buffer")
    ("mcd", po::value<std::string>()->required(),
      "MCD file that describes the input")
    ("train,T", po::value<std::string>()->required(),
      "Training corpus formated according to the MCD");

  po::options_description opt("Optional");
  opt.add_options()
    ("help,h", "Produce this help message")
    ("debug,d", "Print infos on stderr")
    ("alwaysSave", "Save the model at every iteration")
    ("noNeuralNetwork", "Don't use any neural network, useful to speed up debug")
    ("randomDebug", "Print infos on stderr with a probability of randomDebugProbability")
    ("randomDebugProbability", po::value<float>()->default_value(0.001),
      "Probability that debug infos will be printed")
    ("printEntropy", "Print mean entropy and standard deviation accross sequences")
    ("dicts", po::value<std::string>()->default_value(""),
      "The .dict file describing all the dictionaries to be used in the experiement. By default the filename specified in the .tm file will be used")
    ("featureModels", po::value<std::string>()->default_value(""),
      "For each classifier, specify what .fm (feature model) file to use. By default the filename specified in the .cla file will be used. Example : --featureModel Parser=parser.fm,Tagger=tagger.fm")
    ("optimizer", po::value<std::string>()->default_value("amsgrad"),
      "The learning algorithm to use : amsgrad | adam | sgd")
    ("loss", po::value<std::string>()->default_value("neglogsoftmax"),
      "The loss function to use : neglogsoftmax")
    ("dev", po::value<std::string>()->default_value(""),
      "Development corpus formated according to the MCD")
    ("lang", po::value<std::string>()->default_value("fr"),
      "Language you are working with")
    ("nbiter,n", po::value<int>()->default_value(5),
      "Number of training epochs (iterations)")
    ("iterationSize", po::value<int>()->default_value(-1),
      "The number of examples for each iteration. -1 means the whole training set")
    ("lr", po::value<float>()->default_value(0.001),
      "Learning rate of the optimizer")
    ("seed,s", po::value<int>()->default_value(100),
      "The random seed that will initialize RNG")
    ("nbTrain", po::value<int>()->default_value(0),
      "The number of models that will be trained, with only the random seed changing")
    ("duplicates", po::value<bool>()->default_value(true),
      "Remove identical training examples")
    ("showFeatureRepresentation", po::value<int>()->default_value(0),
      "For each state of the Config, show its feature representation")
    ("interactive", po::value<bool>()->default_value(true),
      "Is the shell interactive ? Display advancement informations")
    ("randomEmbeddings", po::value<bool>()->default_value(false),
      "When activated, the embeddings will be randomly initialized")
    ("randomParameters", po::value<bool>()->default_value(true),
      "When activated, the parameters will be randomly initialized")
    ("sequenceDelimiterTape", po::value<std::string>()->default_value("EOS"),
      "The name of the buffer's tape that contains the delimiter token for a sequence, or 0 not to use sequences")
    ("sequenceDelimiter", po::value<std::string>()->default_value("1"),
      "The value of the token that act as a delimiter for sequences")
    ("batchSize", po::value<int>()->default_value(50),
      "The size of each minibatch (in number of taining examples)")
    ("dictCapacity", po::value<int>()->default_value(50000),
      "The maximal size of each Dict (number of differents embeddings).")
    ("maxStackSize", po::value<int>()->default_value(200),
      "The maximal size of the stack (transition based parsing).")
    ("tapeToMask", po::value<std::string>()->default_value("FORM"),
      "The name of the Tape for which some of the elements will be masked.")
    ("maskRate", po::value<float>()->default_value(0.0),
      "The rate of elements of the Tape that will be masked.")
    ("printTime", "Print time on stderr.")
    ("featureExtraction", "Use macaon only as a feature extractor, print corpus to stdout.")
    ("devEvalOnGold", "If true, dev accuracy will be computed on gold configurations.")
    ("devLoss", po::value<bool>()->default_value(false),
      "Compute and print total loss on dev for every epoch.")
    ("shuffle", po::value<bool>()->default_value(true),
      "Shuffle examples after each iteration");

  po::options_description oracle("Oracle related options");
  oracle.add_options()
    ("epochd", po::value<int>()->default_value(3),
      "Number of the first epoch where the oracle will be dynamic")
    ("proba", po::value<float>()->default_value(0.9),
      "The probability that the dynamic oracle will chose the predicted action");

  po::options_description ams("Amsgrad family optimizers");
  ams.add_options()
    ("b1", po::value<float>()->default_value(0.9),
      "beta1 parameter for the Amsgtad or Adam optimizer")
    ("b2", po::value<float>()->default_value(0.999),
      "beta2 parameter for the Amsgtad or Adam optimizer")
    ("bias", po::value<float>()->default_value(1e-8),
      "bias parameter for the Amsgtad or Adam  or Adagrad optimizer");

  po::options_description ga("Genetic algorithm related options");
  ga.add_options()
    ("nbIndividuals", po::value<int>()->default_value(5),
      "Number of individuals that will be used to take decisions");

  desc.add(req).add(opt).add(oracle).add(ams).add(ga);

  return desc;
}

/// @brief Store the program arguments inside a variables_map
///
/// @param od The description of all the possible options.
/// @param argc The number of arguments given to this program.
/// @param argv The values of arguments given to this program.
///
/// @return The variables map
po::variables_map checkOptions(po::options_description & od, int argc, char ** argv)
{
  po::variables_map vm;

  try {po::store(po::parse_command_line(argc, argv, od), vm);}
  catch(std::exception& e)
  {
    std::cerr << "Error: " << e.what() << "\n";
    od.print(std::cerr);
    exit(1);
  }

  if (vm.count("help"))
  {
    std::cout << od << "\n";
    exit(0);
  }

  try {po::notify(vm);}
  catch(std::exception& e)
  {
    std::cerr << "Error: " << e.what() << "\n";
    od.print(std::cerr);
    exit(1);
  }

  return vm;
}

/// @brief Set all the useful paths relative to expPath
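///
/// For example (hypothetical values) : with MACAON_DIR=/home/user/macaon,
/// lang "fr" and expName "exp1", langPath becomes "/home/user/macaon/fr/"
/// and expPath becomes "/home/user/macaon/fr/bin/exp1/".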
void updatePaths()
{
  const char * MACAON_DIR = std::getenv("MACAON_DIR");
  if (!MACAON_DIR)
  {
    fprintf(stderr, "ERROR (%s) : environment variable MACAON_DIR is not set. Aborting.\n", ERRINFO);
    exit(1);
  }
  std::string slash = "/";
  ProgramParameters::langPath = MACAON_DIR + slash + ProgramParameters::lang + slash;
  ProgramParameters::expPath = ProgramParameters::langPath + "bin/" + ProgramParameters::expName + slash;
  ProgramParameters::templatePath = ProgramParameters::langPath + ProgramParameters::templateName + slash;
  ProgramParameters::tmFilename = ProgramParameters::expPath + ProgramParameters::tmName;
  ProgramParameters::bdFilename = ProgramParameters::expPath + ProgramParameters::bdName;
  ProgramParameters::mcdFilename = ProgramParameters::expPath + ProgramParameters::mcdName;
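  // Corpus paths starting with '/' or '~' are kept as-is (absolute or
  // home-relative); anything else is resolved relative to expPath.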
  if (ProgramParameters::trainName[0] == '/' || ProgramParameters::trainName[0] == '~')
    ProgramParameters::trainFilename = ProgramParameters::trainName;
  else
    ProgramParameters::trainFilename = ProgramParameters::expPath + ProgramParameters::trainName;
  if (ProgramParameters::devName[0] == '/' || ProgramParameters::devName[0] == '~')
    ProgramParameters::devFilename = ProgramParameters::devName;
  else
    ProgramParameters::devFilename = ProgramParameters::expPath + ProgramParameters::devName;
  ProgramParameters::newTemplatePath = ProgramParameters::langPath + "bin/" + ProgramParameters::baseExpName + slash;
}

/// @brief Create the folder containing the current experiment from the template folder
void createExpPath()
{
  // If the new template path is the same as the experiment path,
  // the training should resume where it was stopped.
  // No need to recreate the expPath.
  if (ProgramParameters::newTemplatePath == ProgramParameters::expPath)
    return;

std::string decode = "\
#! /bin/bash\n\
\n\
if [ \"$#\" -lt 2 ]; then\n\
 echo \"Usage : $0 input mcd\"\n\
 exit\n\
fi\n\
\n\
INPUT=$1\n\
MCD=$2\n\
\n\
shift\n\
shift\n\
ARGS=\"\"\n\
for arg in \"$@\"\n\
do\n\
  ARGS=\"$ARGS $arg\"\n\
done\n\
\n\
macaon_decode --lang " + ProgramParameters::lang +  " --tm machine.tm --bd test.bd -I $INPUT --mcd $MCD --expName " + ProgramParameters::expName + "$ARGS";

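  // Each system() call is wrapped in an empty if to silence
  // warn_unused_result warnings; failures (e.g. removing an expPath
  // that does not exist yet) are deliberately ignored.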
  if (system(("rm -r " + ProgramParameters::expPath + " 2> /dev/null").c_str())){}
  if (system(("mkdir " + ProgramParameters::expPath).c_str())){}
  if (system(("cp -r " + ProgramParameters::newTemplatePath + "* " + ProgramParameters::expPath + ".").c_str())){}
  if (system(("echo \'" + decode + "\' > " + ProgramParameters::expPath + "decode.sh").c_str())){}
  if (system(("chmod +x " + ProgramParameters::expPath + "decode.sh").c_str())){}
  if (system(("ln -f -s " + ProgramParameters::expPath + "decode.sh " + ProgramParameters::langPath + "bin/maca_tm_" + ProgramParameters::expName).c_str())){}

  ProgramParameters::tapeSize = getNbLines(ProgramParameters::trainFilename);
  ProgramParameters::devTapeSize = ProgramParameters::devFilename.empty() ? 0 : getNbLines(ProgramParameters::devFilename);
  ProgramParameters::readSize = ProgramParameters::tapeSize;
}

/// @brief Train a model according to all the ProgramParameters
void launchTraining()
{
  TransitionMachine transitionMachine(true);

  BD trainBD(ProgramParameters::bdFilename, ProgramParameters::mcdFilename);
  Config trainConfig(trainBD, ProgramParameters::trainFilename);

  std::unique_ptr<BD> devBD;
  std::unique_ptr<Config> devConfig;

  std::unique_ptr<Trainer> trainer;

  if(ProgramParameters::devFilename.empty())
  {
    trainer.reset(new Trainer(transitionMachine, trainBD, trainConfig));
  }
  else
  {
    devBD.reset(new BD(ProgramParameters::bdFilename, ProgramParameters::mcdFilename));
    devConfig.reset(new Config(*devBD, ProgramParameters::devFilename));
    trainer.reset(new Trainer(transitionMachine, trainBD, trainConfig, devBD.get(), devConfig.get()));
  }

  trainer->train();
}

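/// @brief Create the new template folder by copying the original template folder.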
void createTemplatePath()
{
  if (system(("rm -r " + ProgramParameters::newTemplatePath + " 2> /dev/null").c_str())){}
  if (system(("mkdir " + ProgramParameters::newTemplatePath).c_str())){}
  if (system(("cp -r " + ProgramParameters::templatePath + "* " + ProgramParameters::newTemplatePath + ".").c_str())){}
}

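/// @brief Remove the new template folder.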
void removeTemplatePath()
{
  if (system(("rm -r " + ProgramParameters::newTemplatePath + " 2> /dev/null").c_str())){}
}

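/// @brief Parse the command line arguments and store their values inside ProgramParameters.
///
/// @param argc The number of arguments given to this program.
/// @param argv The values of arguments given to this program.
///
/// A plausible call sequence in a training entry point (illustrative
/// sketch only; the real main() lives elsewhere and may differ) :
/// @code
/// int main(int argc, char * argv[])
/// {
///   loadTrainProgramParameters(argc, argv);
///   updatePaths();
///   createTemplatePath();
///   createExpPath();
///   launchTraining();
///   removeTemplatePath();
///   return 0;
/// }
/// @endcode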
void loadTrainProgramParameters(int argc, char * argv[])
{
  auto od = getTrainOptionsDescription();

  po::variables_map vm = checkOptions(od, argc, argv);

  ProgramParameters::expName = vm["expName"].as<std::string>();
  ProgramParameters::baseExpName = ProgramParameters::expName;
  ProgramParameters::templateName = vm["templateName"].as<std::string>();
  ProgramParameters::tmName = vm["tm"].as<std::string>();
  ProgramParameters::bdName = vm["bd"].as<std::string>();
  ProgramParameters::mcdName = vm["mcd"].as<std::string>();
  ProgramParameters::debug = vm.count("debug") != 0;
  ProgramParameters::alwaysSave = vm.count("alwaysSave") != 0;
  ProgramParameters::noNeuralNetwork = vm.count("noNeuralNetwork") != 0;
  ProgramParameters::randomDebug = vm.count("randomDebug") != 0;
  ProgramParameters::printEntropy = vm.count("printEntropy") != 0;
  ProgramParameters::printTime = vm.count("printTime") != 0;
  ProgramParameters::featureExtraction = vm.count("featureExtraction") != 0;
  ProgramParameters::devEvalOnGold = vm.count("devEvalOnGold") != 0;
  ProgramParameters::trainName = vm["train"].as<std::string>();
  ProgramParameters::devName = vm["dev"].as<std::string>();
  ProgramParameters::lang = vm["lang"].as<std::string>();
  ProgramParameters::nbIter = vm["nbiter"].as<int>();
  ProgramParameters::seed = vm["seed"].as<int>();
  ProgramParameters::batchSize = vm["batchSize"].as<int>();
  ProgramParameters::dictCapacity = vm["dictCapacity"].as<int>();
  ProgramParameters::maxStackSize = vm["maxStackSize"].as<int>();
  ProgramParameters::nbTrain = vm["nbTrain"].as<int>();
  ProgramParameters::removeDuplicates = vm["duplicates"].as<bool>();
  ProgramParameters::interactive = vm["interactive"].as<bool>();
  ProgramParameters::shuffleExamples = vm["shuffle"].as<bool>();
  ProgramParameters::devLoss = vm["devLoss"].as<bool>();
  ProgramParameters::randomEmbeddings = vm["randomEmbeddings"].as<bool>();
  ProgramParameters::randomParameters = vm["randomParameters"].as<bool>();
  ProgramParameters::sequenceDelimiterTape = vm["sequenceDelimiterTape"].as<std::string>();
  ProgramParameters::sequenceDelimiter = vm["sequenceDelimiter"].as<std::string>();
  ProgramParameters::learningRate = vm["lr"].as<float>();
  ProgramParameters::randomDebugProbability = vm["randomDebugProbability"].as<float>();
  ProgramParameters::beta1 = vm["b1"].as<float>();
  ProgramParameters::beta2 = vm["b2"].as<float>();
  ProgramParameters::bias = vm["bias"].as<float>();
  ProgramParameters::nbIndividuals = vm["nbIndividuals"].as<int>();
  ProgramParameters::optimizer = vm["optimizer"].as<std::string>();
  ProgramParameters::dicts = vm["dicts"].as<std::string>();
  ProgramParameters::loss = vm["loss"].as<std::string>();
  ProgramParameters::dynamicEpoch = vm["epochd"].as<int>();
  ProgramParameters::dynamicProbability = vm["proba"].as<float>();
  ProgramParameters::tapeToMask = vm["tapeToMask"].as<std::string>();
  ProgramParameters::maskRate = vm["maskRate"].as<float>();
  ProgramParameters::showFeatureRepresentation = vm["showFeatureRepresentation"].as<int>();
  ProgramParameters::iterationSize = vm["iterationSize"].as<int>();
  std::string featureModels = vm["featureModels"].as<std::string>();
  if (!featureModels.empty())
  {
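    // Expected format (cf. the featureModels option description) :
    // Classifier1=file1.fm,Classifier2=file2.fm
    // e.g. "Parser=parser.fm,Tagger=tagger.fm" maps Parser -> parser.fm.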
    auto byClassifiers = split(featureModels, ',');
    for (auto & classifier : byClassifiers)
    {
      auto parts = split(classifier, '=');
      if (parts.size() != 2)
      {
        fprintf(stderr, "ERROR (%s) : wrong format for argument of option featureModels. Aborting.\n", ERRINFO);
        exit(1);
      }
      ProgramParameters::featureModelByClassifier[parts[0]] = parts[1];
    }
  }
}