from bolsonaro.data.dataset_parameters import DatasetParameters
from bolsonaro.data.dataset_loader import DatasetLoader
from bolsonaro.models.model_factory import ModelFactory
from bolsonaro.models.model_parameters import ModelParameters
from bolsonaro.trainer import Trainer
from bolsonaro.utils import resolve_experiment_id

import argparse
import pathlib
import random
import os
import errno


def parse_arguments(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. When ``None``, argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. ``extracted_forest_size`` is always
        a list of ints, even when the option was omitted and the scalar
        default applied.
    """
    default_dataset_name = 'boston'
    default_normalize = False
    default_forest_size = 100
    default_extracted_forest_size = 10
    default_models_dir = 'models'
    default_dev_size = 0.2
    default_test_size = 0.2
    default_use_random_seed = True
    default_random_seed_number = 1

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--dataset_name', nargs='?', type=str, default=default_dataset_name, help='Specify the dataset. Regression: boston, diabetes, linnerud, california_housing. Classification: iris, digits, wine, breast_cancer, olivetti_faces, 20newsgroups, 20newsgroups_vectorized, lfw_people, lfw_pairs, covtype, rcv1, kddcup99.')
    parser.add_argument('--normalize', action='store_true', default=default_normalize, help='Normalize the data by doing the L2 division of the pred vectors.')
    parser.add_argument('--forest_size', nargs='?', type=int, default=default_forest_size, help='The number of trees of the random forest.')
    parser.add_argument('--extracted_forest_size', nargs='+', type=int, default=default_extracted_forest_size, help='The number of trees selected by OMP.')
    parser.add_argument('--models_dir', nargs='?', type=str, default=default_models_dir, help='The output directory of the trained models.')
    parser.add_argument('--dev_size', nargs='?', type=float, default=default_dev_size, help='Dev subset ratio')
    parser.add_argument('--test_size', nargs='?', type=float, default=default_test_size, help='Test subset ratio')
    # NOTE(review): 'store_true' combined with a default of True means this
    # flag can never be switched off from the command line.
    parser.add_argument('--use_random_seed', action='store_true', default=default_use_random_seed, help='Random seed used for the data split')
    parser.add_argument('--random_seed_number', nargs='?', type=int, default=default_random_seed_number, help='Number of random seeds used')
    args = parser.parse_args(argv)

    # nargs='+' yields a list when the option is given, but the default is a
    # bare int; wrap it so callers can always iterate.
    if not isinstance(args.extracted_forest_size, list):
        args.extracted_forest_size = [args.extracted_forest_size]

    return args


def draw_random_seeds(count, use_random_seed=True, low=1, high=2000):
    """Draw the seeds driving the successive experiment runs.

    Args:
        count: Number of seeds to draw.
        use_random_seed: When false, no seed is drawn and a single unseeded
            run is requested instead.
        low: Inclusive lower bound of the seed range.
        high: Inclusive upper bound of the seed range.

    Returns:
        A list of ``count`` ints in ``[low, high]``, or ``[None]`` when
        ``use_random_seed`` is false. Returning ``[None]`` rather than
        ``None`` keeps the result iterable, so the caller's loop runs once
        instead of raising ``TypeError``.
    """
    if not use_random_seed:
        return [None]
    # NOTE(review): seeds are drawn independently, so the same value may come
    # up twice; such runs would share one output directory.
    return [random.randint(low, high) for _ in range(count)]


def main(argv=None):
    """Run one experiment: one training pass per (seed, extracted forest size).

    Every seed gets the directory
    ``<models_dir>/<experiment_id>/seeds/<seed>`` and every extracted forest
    size a ``extracted_forest_size/<size>`` sub-directory below it. The
    dataset and model parameters are saved into those directories before the
    trainer is invoked.

    Args:
        argv: Optional list of argument strings, forwarded to
            ``parse_arguments``.
    """
    args = parse_arguments(argv)

    pathlib.Path(args.models_dir).mkdir(parents=True, exist_ok=True)

    random_seeds = draw_random_seeds(args.random_seed_number, args.use_random_seed)

    experiment_id = resolve_experiment_id(args.models_dir)
    experiment_id_str = str(experiment_id)

    for random_seed in random_seeds:
        # Paths are kept as plain strings because they are handed to project
        # code whose expectations are not visible from this script.
        models_dir = os.path.join(args.models_dir, experiment_id_str, 'seeds', str(random_seed))
        os.makedirs(models_dir, exist_ok=True)

        dataset_parameters = DatasetParameters(
            name=args.dataset_name,
            test_size=args.test_size,
            dev_size=args.dev_size,
            random_state=random_seed,
            normalize=args.normalize
        )
        dataset_parameters.save(models_dir, experiment_id_str)

        dataset = DatasetLoader.load_from_name(dataset_parameters)

        # The dataset is loaded once per seed and shared across all sizes.
        trainer = Trainer(dataset)

        for extracted_forest_size in args.extracted_forest_size:
            sub_models_dir = os.path.join(models_dir, 'extracted_forest_size', str(extracted_forest_size))
            os.makedirs(sub_models_dir, exist_ok=True)

            model_parameters = ModelParameters(
                forest_size=args.forest_size,
                extracted_forest_size=extracted_forest_size,
                seed=random_seed
            )
            # NOTE(review): this receives the raw experiment id whereas
            # dataset_parameters.save() gets its string form -- confirm that
            # both save() methods accept either.
            model_parameters.save(sub_models_dir, experiment_id)

            model = ModelFactory.build(dataset.task, model_parameters)

            trainer.iterate(model, sub_models_dir)


if __name__ == "__main__":
    main()