from bolsonaro.data.dataset_parameters import DatasetParameters
from bolsonaro.data.dataset_loader import DatasetLoader
from bolsonaro.models.model_raw_results import ModelRawResults
from bolsonaro.models.model_factory import ModelFactory
from bolsonaro.visualization.plotter import Plotter

import argparse
import pathlib
from dotenv import find_dotenv, load_dotenv
import os


def numeric_sort_key(name):
    """Return a sort key ordering numeric names by value, then other names.

    Directory names such as '10', '2', '1' sort as 1, 2, 10 instead of
    lexicographically; any name that is not an integer sorts after every
    numeric one, alphabetically.
    """
    try:
        return (0, int(name), name)
    except ValueError:
        return (1, 0, name)


def list_subdirectories(root_path):
    """Return the names of the directories directly under root_path, sorted.

    ``os.listdir`` yields entries in arbitrary order, so the result is sorted
    with ``numeric_sort_key`` to give a stable, numerically increasing order.
    Plain files (e.g. a stray ``.gitkeep``) are skipped so they are not
    mistaken for experiment, seed or forest-size directories.
    """
    return sorted(
        (name for name in os.listdir(root_path)
         if os.path.isdir(os.path.join(root_path, name))),
        key=numeric_sort_key,
    )


if __name__ == "__main__":
    # get environment variables in .env
    load_dotenv(find_dotenv('.env.example'))

    DEFAULT_RESULTS_DIR = os.environ["project_dir"] + os.sep + 'results'
    DEFAULT_MODELS_DIR = os.environ["project_dir"] + os.sep + 'models'
    DEFAULT_EXPERIMENT_IDS = None

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--results_dir', nargs='?', type=str, default=DEFAULT_RESULTS_DIR, help='The output directory of the results.')
    parser.add_argument('--models_dir', nargs='?', type=str, default=DEFAULT_MODELS_DIR, help='The output directory of the trained models.')
    parser.add_argument('--experiment_ids', nargs='+', type=int, default=DEFAULT_EXPERIMENT_IDS, help='Compute the results of the specified experiment id(s)')
    args = parser.parse_args()

    pathlib.Path(args.results_dir).mkdir(parents=True, exist_ok=True)

    # Use the ids given on the command line, otherwise every experiment
    # directory found under the models directory.
    if args.experiment_ids is not None:
        experiments_ids = [str(experiment_id) for experiment_id in args.experiment_ids]
    else:
        experiments_ids = list_subdirectories(args.models_dir)

    if not experiments_ids:
        raise ValueError("No experiment id was found or specified.")

    for experiment_id in experiments_ids:
        experiment_id_path = args.models_dir + os.sep + experiment_id
        pathlib.Path(args.results_dir + os.sep + experiment_id).mkdir(parents=True, exist_ok=True)
        experiment_seed_root_path = experiment_id_path + os.sep + 'seeds'

        # seed -> list of scores, one per extracted forest size (same order
        # as extracted_forest_sizes).
        experiment_train_scores = dict()
        experiment_dev_scores = dict()
        experiment_test_scores = dict()
        experiment_score_metrics = list()
        # Shared x axis of the plot; every seed must provide the same sizes,
        # otherwise the per-seed score lists would not line up with it.
        extracted_forest_sizes = None

        for seed in list_subdirectories(experiment_seed_root_path):
            experiment_seed_path = experiment_seed_root_path + os.sep + seed
            dataset_parameters = DatasetParameters.load(experiment_seed_path, experiment_id)
            dataset = DatasetLoader.load(dataset_parameters)
            extracted_forest_size_root_path = experiment_seed_path + os.sep + 'extracted_forest_size'

            seed_forest_sizes = list_subdirectories(extracted_forest_size_root_path)
            if extracted_forest_sizes is None:
                extracted_forest_sizes = seed_forest_sizes
            elif seed_forest_sizes != extracted_forest_sizes:
                raise ValueError(
                    "Seed '{}' of experiment '{}' has extracted forest sizes {} "
                    "but previous seeds have {}".format(
                        seed, experiment_id, seed_forest_sizes, extracted_forest_sizes))

            experiment_train_scores[seed] = list()
            experiment_dev_scores[seed] = list()
            experiment_test_scores[seed] = list()

            for extracted_forest_size in seed_forest_sizes:
                extracted_forest_size_path = extracted_forest_size_root_path + os.sep + extracted_forest_size
                model_raw_results = ModelRawResults.load(extracted_forest_size_path)
                # NOTE(review): the loaded model is not used below; the call is
                # kept so a missing/unreadable saved model still makes this
                # script fail as before — confirm whether this check is wanted.
                ModelFactory.load(dataset.task, extracted_forest_size_path, experiment_id, model_raw_results)
                experiment_train_scores[seed].append(model_raw_results.train_score)
                experiment_dev_scores[seed].append(model_raw_results.dev_score)
                experiment_test_scores[seed].append(model_raw_results.test_score)
                experiment_score_metrics.append(model_raw_results.score_metric)

        # Without any model there is no metric name for the y label and
        # nothing to plot.
        if not experiment_score_metrics:
            raise ValueError("No trained model was found for experiment '{}'.".format(experiment_id))

        if len(set(experiment_score_metrics)) > 1:
            raise ValueError("The metrics used to compute the dev score aren't the same everytime")

        Plotter.plot_losses(
            file_path=args.results_dir + os.sep + experiment_id + os.sep + 'losses.png',
            all_experiment_scores=[experiment_train_scores, experiment_dev_scores, experiment_test_scores],
            x_value=extracted_forest_sizes,
            xlabel='Number of trees extracted',
            ylabel=experiment_score_metrics[0],
            all_labels=['train', 'dev', 'test'],
            title='Loss values of the trained model'
        )