from bolsonaro.data.dataset_parameters import DatasetParameters
from bolsonaro.data.dataset_loader import DatasetLoader
from bolsonaro.models.model_factory import ModelFactory
from bolsonaro.models.model_parameters import ModelParameters
from bolsonaro.trainer import Trainer
from bolsonaro.utils import resolve_experiment_id
from bolsonaro import LOG_PATH
from bolsonaro.error_handling.logger_factory import LoggerFactory

from dotenv import find_dotenv, load_dotenv
import argparse
import pathlib
import random
import os
from tqdm import tqdm


if __name__ == "__main__":
    # get environment variables in .env
    load_dotenv(find_dotenv('.env.example'))

    DEFAULT_DATASET_NAME = 'boston'
    DEFAULT_NORMALIZE_D = False
    DEFAULT_DATASET_NORMALIZER = None
    DEFAULT_FOREST_SIZE = 100
    DEFAULT_EXTRACTED_FOREST_SIZE = 10
    # the models will be stored in a directory structure like: models/{experiment_id}/seeds/{seed_nb}/extracted_forest_size/{nb_extracted_trees}
    DEFAULT_MODELS_DIR = os.environ["project_dir"] + os.sep + 'models'
    DEFAULT_DEV_SIZE = 0.2
    DEFAULT_TEST_SIZE = 0.2
    DEFAULT_RANDOM_SEED_NUMBER = 1
    DEFAULT_TRAIN_ON_SUBSET = 'train'
    DEFAULT_DISABLE_PROGRESS = False

    begin_random_seed_range = 1
    end_random_seed_range = 2000

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--dataset_name', nargs='?', type=str, default=DEFAULT_DATASET_NAME, help='Specify the dataset. Regression: boston, diabetes, linnerud, california_housing. Classification: iris, digits, wine, breast_cancer, olivetti_faces, 20newsgroups, 20newsgroups_vectorized, lfw_people, lfw_pairs, covtype, rcv1, kddcup99.')
    parser.add_argument('--normalize_D', action='store_true', default=DEFAULT_NORMALIZE_D, help='Specify if we want to normalize the prediction of the forest by doing the L2 division of the pred vectors.')
    parser.add_argument('--dataset_normalizer', nargs='?', type=str, default=DEFAULT_DATASET_NORMALIZER, help='Specify which dataset normalizer use (either standard, minmax, robust or normalizer).')
    parser.add_argument('--forest_size', nargs='?', type=int, default=DEFAULT_FOREST_SIZE, help='The number of trees of the random forest.')
    parser.add_argument('--extracted_forest_size', nargs='+', type=int, default=DEFAULT_EXTRACTED_FOREST_SIZE, help='The number of trees selected by OMP.')
    parser.add_argument('--models_dir', nargs='?', type=str, default=DEFAULT_MODELS_DIR, help='The output directory of the trained models.')
    parser.add_argument('--dev_size', nargs='?', type=float, default=DEFAULT_DEV_SIZE, help='Dev subset ratio.')
    parser.add_argument('--test_size', nargs='?', type=float, default=DEFAULT_TEST_SIZE, help='Test subset ratio.')
    parser.add_argument('--random_seed_number', nargs='?', type=int, default=DEFAULT_RANDOM_SEED_NUMBER, help='Number of random seeds used.')
    parser.add_argument('--seeds', nargs='+', type=int, default=None, help='Specific a list of seeds instead of generate them randomly')
    parser.add_argument('--train_on_subset', nargs='?', type=str, default=DEFAULT_TRAIN_ON_SUBSET, help='Specify on witch subset the model will be trained (either train or dev).')
    parser.add_argument('--disable_progress', action='store_true', default=DEFAULT_DISABLE_PROGRESS, help='Disable the progress bars.')
    args = parser.parse_args()

    pathlib.Path(args.models_dir).mkdir(parents=True, exist_ok=True)

    logger = LoggerFactory.create(LOG_PATH, os.path.basename(__file__))

    args.extracted_forest_size = args.extracted_forest_size \
        if type(args.extracted_forest_size) == list \
        else [args.extracted_forest_size]

    if args.seeds != None and args.random_seed_number > 1:
        logger.warning('seeds and random_seed_number parameters are both specified. Seeds will be used.')    

    seeds = args.seeds if args.seeds is not None \
        else [random.randint(begin_random_seed_range, end_random_seed_range) \
        for i in range(args.random_seed_number)]

    experiment_id = resolve_experiment_id(args.models_dir)
    experiment_id_str = str(experiment_id)

    logger.info('Experiment id: {}'.format(experiment_id_str))

    with tqdm(seeds, disable=args.disable_progress) as seed_bar:
        for seed in seed_bar:
            seed_bar.set_description('seed={}'.format(seed))
            seed_str = str(seed)
            models_dir = args.models_dir + os.sep + experiment_id_str + os.sep + 'seeds' + \
                os.sep + seed_str
            pathlib.Path(models_dir).mkdir(parents=True, exist_ok=True)

            dataset_parameters = DatasetParameters(
                name=args.dataset_name,
                test_size=args.test_size,
                dev_size=args.dev_size,
                random_state=seed,
                dataset_normalizer=args.dataset_normalizer,
                train_on_subset=args.train_on_subset
            )
            dataset_parameters.save(models_dir, experiment_id_str)

            dataset = DatasetLoader.load(dataset_parameters)

            trainer = Trainer(dataset)

            with tqdm(args.extracted_forest_size, disable=args.disable_progress) as extracted_forest_size_bar:
                for extracted_forest_size in extracted_forest_size_bar:
                    extracted_forest_size_bar.set_description('extracted_forest_size={}'.format(extracted_forest_size))
                    sub_models_dir = models_dir + os.sep + 'extracted_forest_size' + os.sep + str(extracted_forest_size)
                    pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)

                    model_parameters = ModelParameters(
                        forest_size=args.forest_size,
                        extracted_forest_size=extracted_forest_size,
                        seed=seed,
                        normalize_D=args.normalize_D
                    )
                    model_parameters.save(sub_models_dir, experiment_id)

                    model = ModelFactory.build(dataset.task, model_parameters)

                    trainer.train(model, sub_models_dir)