from bolsonaro.models.model_raw_results import ModelRawResults
from bolsonaro.visualization.plotter import Plotter
from bolsonaro import LOG_PATH
from bolsonaro.error_handling.logger_factory import LoggerFactory

import argparse
import pathlib
from dotenv import find_dotenv, load_dotenv
import os


def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_dir, experiment_id):
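    """
    Walk through models/{experiment_id}/seeds/{seed}/extracted_forest_sizes/{extracted_forest_size}
    and gather the train/dev/test scores of each pruned model.

    :return: three dicts mapping each seed to its list of train/dev/test scores,
        plus the list of extracted forest sizes (checked to be identical across seeds).
    """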
    experiment_id_path = models_dir + os.sep + str(experiment_id) # models/{experiment_id}
    # Recursively create the results/{experiment_id} directory tree
    pathlib.Path(results_dir + os.sep + str(experiment_id)).mkdir(parents=True, exist_ok=True)
    experiment_seed_root_path = experiment_id_path + os.sep + 'seeds' # models/{experiment_id}/seeds

    """
    Dictionaries to temporarly store the scalar results with the following structure:
    {seed_1: [score_1, ..., score_m], ... seed_n: [score_1, ..., score_k]}
    """
    experiment_train_scores = dict()
    experiment_dev_scores = dict()
    experiment_test_scores = dict()
    all_extracted_forest_sizes = list()

    # Used to check if all losses were computed using the same metric (it should be the case)
    experiment_score_metrics = list()

    # For each seed results stored in models/{experiment_id}/seeds
    seeds = os.listdir(experiment_seed_root_path)
    seeds.sort(key=int)
    for seed in seeds:
        experiment_seed_path = experiment_seed_root_path + os.sep + seed # models/{experiment_id}/seeds/{seed}
        extracted_forest_sizes_root_path = experiment_seed_path + os.sep + 'extracted_forest_sizes' # models/{experiment_id}/seeds/{seed}/extracted_forest_sizes

        # Initialize an empty score list for this seed: {seed: []}
        experiment_train_scores[seed] = list()
        experiment_dev_scores[seed] = list()
        experiment_test_scores[seed] = list()

        # List the forest sizes in models/{experiment_id}/seeds/{seed}/extracted_forest_sizes
        extracted_forest_sizes = os.listdir(extracted_forest_sizes_root_path)
        extracted_forest_sizes.sort(key=int)
        all_extracted_forest_sizes.append(list(map(int, extracted_forest_sizes)))
        for extracted_forest_size in extracted_forest_sizes:
            # models/{experiment_id}/seeds/{seed}/extracted_forest_sizes/{extracted_forest_size}
            extracted_forest_size_path = extracted_forest_sizes_root_path + os.sep + extracted_forest_size
            # Load models/{experiment_id}/seeds/{seed}/extracted_forest_sizes/{extracted_forest_size}/model_raw_results.pickle file
            model_raw_results = ModelRawResults.load(extracted_forest_size_path)
            # Save the scores
            experiment_train_scores[seed].append(model_raw_results.train_score)
            experiment_dev_scores[seed].append(model_raw_results.dev_score)
            experiment_test_scores[seed].append(model_raw_results.test_score)
            # Save the metric
            experiment_score_metrics.append(model_raw_results.score_metric)

    # Sanity checks
    if len(set(experiment_score_metrics)) > 1:
        raise ValueError("The metrics used to compute the scores aren't the same across seeds.")
    # Compare the full lists (not just their sums) to make sure every seed used the same extracted forest sizes
    if len(set(map(tuple, all_extracted_forest_sizes))) != 1:
        raise ValueError("The extracted forest sizes aren't the same across seeds.")

    return experiment_train_scores, experiment_dev_scores, experiment_test_scores, all_extracted_forest_sizes[0]

def extract_scores_across_seeds_and_forest_size(models_dir, results_dir, experiment_id, extracted_forest_sizes_number):
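    """
    Load the single base forest result stored in models/{experiment_id}/seeds/{seed}/forest_size/{forest_size}
    and repeat its train/dev/test scores extracted_forest_sizes_number times, so that this constant baseline
    can be plotted against the same x-axis as the extracted forest sizes.

    :return: three dicts mapping each seed to its (repeated) train/dev/test scores.
    """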
    experiment_id_path = models_dir + os.sep + str(experiment_id) # models/{experiment_id}
    # Recursively create the results/{experiment_id} directory tree
    pathlib.Path(results_dir + os.sep + str(experiment_id)).mkdir(parents=True, exist_ok=True)
    experiment_seed_root_path = experiment_id_path + os.sep + 'seeds' # models/{experiment_id}/seeds

    """
    Dictionaries to temporarly store the scalar results with the following structure:
    {seed_1: [score_1, ..., score_m], ... seed_n: [score_1, ..., score_k]}
    """
    experiment_train_scores = dict()
    experiment_dev_scores = dict()
    experiment_test_scores = dict()

    # Used to check if all losses were computed using the same metric (it should be the case)
    experiment_score_metrics = list()

    # For each seed results stored in models/{experiment_id}/seeds
    seeds = os.listdir(experiment_seed_root_path)
    seeds.sort(key=int)
    for seed in seeds:
        experiment_seed_path = experiment_seed_root_path + os.sep + seed # models/{experiment_id}/seeds/{seed}
        forest_size_root_path = experiment_seed_path + os.sep + 'forest_size' # models/{experiment_id}/seeds/{seed}/forest_size

        # Initialize an empty score list for this seed: {seed: []}
        experiment_train_scores[seed] = list()
        experiment_dev_scores[seed] = list()
        experiment_test_scores[seed] = list()

        # Only a single forest size is expected per seed here (take the first entry)
        forest_size = os.listdir(forest_size_root_path)[0]
        # models/{experiment_id}/seeds/{seed}/forest_size/{forest_size}
        forest_size_path = forest_size_root_path + os.sep + forest_size
        # Load the models/{experiment_id}/seeds/{seed}/forest_size/{forest_size}/model_raw_results.pickle file
        model_raw_results = ModelRawResults.load(forest_size_path)
        # Duplicate the base forest scores so they align with each extracted forest size
        for _ in range(extracted_forest_sizes_number):
            # Save the scores
            experiment_train_scores[seed].append(model_raw_results.train_score)
            experiment_dev_scores[seed].append(model_raw_results.dev_score)
            experiment_test_scores[seed].append(model_raw_results.test_score)
            # Save the metric
            experiment_score_metrics.append(model_raw_results.score_metric)

    if len(set(experiment_score_metrics)) > 1:
        raise ValueError("The metrics used to compute the scores aren't the same across seeds.")

    return experiment_train_scores, experiment_dev_scores, experiment_test_scores

if __name__ == "__main__":
    # get environment variables in .env
    load_dotenv(find_dotenv('.env'))

    DEFAULT_RESULTS_DIR = os.environ["project_dir"] + os.sep + 'results'
    DEFAULT_MODELS_DIR = os.environ["project_dir"] + os.sep + 'models'

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--stage', nargs='?', type=int, required=True, help='Specify the stage number among [1, 4].')
    parser.add_argument('--experiment_ids', nargs='+', type=int, required=True, help='Compute the results of the specified experiment id(s). ' + \
        'stage=1: {{base_with_params}} {{random_with_params}} {{omp_with_params}} {{base_wo_params}} {{random_wo_params}} {{omp_wo_params}}')
    parser.add_argument('--dataset_name', nargs='?', type=str, required=True, help='Specify the dataset name. TODO: read it from models dir directly.')
    parser.add_argument('--results_dir', nargs='?', type=str, default=DEFAULT_RESULTS_DIR, help='The output directory of the results.')
    parser.add_argument('--models_dir', nargs='?', type=str, default=DEFAULT_MODELS_DIR, help='The output directory of the trained models.')
    args = parser.parse_args()

    if args.stage not in list(range(1, 5)):
        raise ValueError('stage must be a supported stage id (i.e. in [1, 4]).')

    logger = LoggerFactory.create(LOG_PATH, os.path.basename(__file__))

    # Recursively create the results directory tree
    pathlib.Path(args.results_dir).mkdir(parents=True, exist_ok=True)

    if args.stage == 1:
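        # Note: this value is assumed to match the number of extracted forest sizes used in the
        # random/omp experiments; if it doesn't, the baseline and pruned curves won't be aligned.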
        extracted_forest_sizes_number = 5 # TODO: hardcoded
        if len(args.experiment_ids) != 6:
            raise ValueError('In the case of stage 1, the number of specified experiment ids must be 6.')

        # Experiments that used the best hyperparameters found for this dataset

        # base_with_params
        logger.info('Loading base_with_params experiment scores...')
        base_with_params_train_scores, base_with_params_dev_scores, base_with_params_test_scores = \
            extract_scores_across_seeds_and_forest_size(args.models_dir, args.results_dir, args.experiment_ids[0],
            extracted_forest_sizes_number)
        # random_with_params
        logger.info('Loading random_with_params experiment scores...')
        random_with_params_train_scores, random_with_params_dev_scores, random_with_params_test_scores, \
            with_params_extracted_forest_sizes = extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, args.experiment_ids[1])
        # omp_with_params
        logger.info('Loading omp_with_params experiment scores...')
        omp_with_params_train_scores, omp_with_params_dev_scores, omp_with_params_test_scores, _ = \
            extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, args.experiment_ids[2])

        # Experiments that didn't use the best hyperparameters found for this dataset

        # base_wo_params
        logger.info('Loading base_wo_params experiment scores...')
        base_wo_params_train_scores, base_wo_params_dev_scores, base_wo_params_test_scores = \
            extract_scores_across_seeds_and_forest_size(args.models_dir, args.results_dir, args.experiment_ids[3],
            extracted_forest_sizes_number)
        # random_wo_params
        logger.info('Loading random_wo_params experiment scores...')
        random_wo_params_train_scores, random_wo_params_dev_scores, random_wo_params_test_scores, \
            wo_params_extracted_forest_sizes = extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, args.experiment_ids[4])
        # omp_wo_params
        logger.info('Loading omp_wo_params experiment scores...')
        omp_wo_params_train_scores, omp_wo_params_dev_scores, omp_wo_params_test_scores, _ = \
            extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, args.experiment_ids[5])

        output_path = os.path.join(args.results_dir, args.dataset_name, 'stage1')
        pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)

        Plotter.plot_losses(
            file_path=output_path + os.sep + 'losses_with_params.png',
            all_experiment_scores=[base_with_params_train_scores, base_with_params_dev_scores, base_with_params_test_scores,
                random_with_params_train_scores, random_with_params_dev_scores, random_with_params_test_scores,
                omp_with_params_train_scores, omp_with_params_dev_scores, omp_with_params_test_scores],
            x_value=with_params_extracted_forest_sizes,
            xlabel='Number of trees extracted',
            ylabel='MSE', # TODO: hardcoded
            all_labels=['base_with_params_train', 'base_with_params_dev', 'base_with_params_test',
                'random_with_params_train', 'random_with_params_dev', 'random_with_params_test',
                'omp_with_params_train', 'omp_with_params_dev', 'omp_with_params_test'],
            title='Loss values of {} using the best hyperparams'.format(args.dataset_name)
        )
        Plotter.plot_losses(
            file_path=output_path + os.sep + 'losses_wo_params.png',
            all_experiment_scores=[base_wo_params_train_scores, base_wo_params_dev_scores, base_wo_params_test_scores,
                random_wo_params_train_scores, random_wo_params_dev_scores, random_wo_params_test_scores,
                omp_wo_params_train_scores, omp_wo_params_dev_scores, omp_wo_params_test_scores],
            x_value=wo_params_extracted_forest_sizes,
            xlabel='Number of trees extracted',
            ylabel='MSE', # TODO: hardcoded
            all_labels=['base_wo_params_train', 'base_wo_params_dev', 'base_wo_params_test',
                'random_wo_params_train', 'random_wo_params_dev', 'random_wo_params_test',
                'omp_wo_params_train', 'omp_wo_params_dev', 'omp_wo_params_test'],
            title='Loss values of {} without using the best hyperparams'.format(args.dataset_name)
        )
    else:
        raise ValueError('This stage number is not supported yet, but it will be!')

    """
    TODO:
    For each dataset:
    Stage 1) A figure for the selection of the best base forest model hyperparameters (best vs default/random hyperparams)
    Stage 2) A figure for the selection of the best combination of normalization: D normalization vs weights normalization (4 combinations)
    Stage 3) A figure for the selection of the most relevant subsets combination: train,dev vs train+dev,train+dev vs train,train+dev
    Stage 4) A figure to finally compare the perf of our approach using the previous selected
        parameters vs the baseline vs other papers using different extracted forest size
        (percentage of the tree size found previously in best hyperparams search) on the abscissa.

    IMPORTANT: Compare experiments that used the same seeds among them (except for stage 1).
    """