# train.py — train random forests and OMP-extracted sub-forests over several
# seeds and extracted-forest sizes, saving parameters and models per run.
from bolsonaro.data.dataset_parameters import DatasetParameters
from bolsonaro.data.dataset_loader import DatasetLoader
from bolsonaro.models.model_factory import ModelFactory
4
from bolsonaro.models.model_parameters import ModelParameters
5
from bolsonaro.trainer import Trainer
6
from bolsonaro.utils import resolve_experiment_id
7
8
from bolsonaro import LOG_PATH
from bolsonaro.error_handling.logger_factory import LoggerFactory
9

10
from dotenv import find_dotenv, load_dotenv
11
12
13
import argparse
import pathlib
import random
14
import os
15
from tqdm import tqdm
16
17
18


if __name__ == "__main__":
    # Load environment variables (this script reads `project_dir` below).
    # NOTE(review): this loads '.env.example' rather than '.env' — confirm
    # that is intentional and not a leftover from bootstrapping.
    load_dotenv(find_dotenv('.env.example'))

    # Default values for the command-line parameters.
    default_dataset_name = 'boston'
    default_normalize = True
    default_wo_normalization = False
    default_forest_size = 100
    default_extracted_forest_size = 10
    # The models will be stored in a directory structure like:
    # models/{experiment_id}/seeds/{seed_nb}/extracted_forest_size/{nb_extracted_trees}
    default_models_dir = os.environ["project_dir"] + os.sep + 'models'
    default_dev_size = 0.2
    default_test_size = 0.2
    default_random_seed_number = 1
    begin_random_seed_range = 1
    end_random_seed_range = 2000
    default_train_on_subset = 'train'

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--dataset_name', nargs='?', type=str, default=default_dataset_name, help='Specify the dataset. Regression: boston, diabetes, linnerud, california_housing. Classification: iris, digits, wine, breast_cancer, olivetti_faces, 20newsgroups, 20newsgroups_vectorized, lfw_people, lfw_pairs, covtype, rcv1, kddcup99.')
    # Typo fixed in help text: "Withouyt normalize" -> "Without normalizing".
    parser.add_argument('--wo_normalization', action='store_true', default=default_wo_normalization, help='Without normalizing the data by doing the L2 division of the pred vectors.')
    parser.add_argument('--forest_size', nargs='?', type=int, default=default_forest_size, help='The number of trees of the random forest.')
    parser.add_argument('--extracted_forest_size', nargs='+', type=int, default=default_extracted_forest_size, help='The number of trees selected by OMP.')
    parser.add_argument('--models_dir', nargs='?', type=str, default=default_models_dir, help='The output directory of the trained models.')
    parser.add_argument('--dev_size', nargs='?', type=float, default=default_dev_size, help='Dev subset ratio.')
    parser.add_argument('--test_size', nargs='?', type=float, default=default_test_size, help='Test subset ratio.')
    parser.add_argument('--random_seed_number', nargs='?', type=int, default=default_random_seed_number, help='Number of random seeds used.')
    parser.add_argument('--seeds', nargs='+', type=int, default=None, help='Specify a list of seeds instead of generating them randomly.')
    parser.add_argument('--train_on_subset', nargs='?', type=str, default=default_train_on_subset, help='Specify on which subset the model will be trained (either train or dev).')
    args = parser.parse_args()

    pathlib.Path(args.models_dir).mkdir(parents=True, exist_ok=True)

    logger = LoggerFactory.create(LOG_PATH, os.path.basename(__file__))

    # argparse's default yields a bare int for --extracted_forest_size, while
    # nargs='+' yields a list; normalize to a list so the loop below works.
    if not isinstance(args.extracted_forest_size, list):
        args.extracted_forest_size = [args.extracted_forest_size]

    # Explicit seeds take precedence over randomly generated ones.
    if args.seeds is not None and args.random_seed_number > 1:
        logger.warning('seeds and random_seed_number parameters are both specified. Seeds will be used.')

    seeds = args.seeds if args.seeds is not None \
        else [random.randint(begin_random_seed_range, end_random_seed_range)
              for _ in range(args.random_seed_number)]

    # --wo_normalization is an opt-out flag disabling the default normalization.
    normalize = default_normalize and not args.wo_normalization
    logger.debug('normalize={}'.format(normalize))

    # Each invocation gets a fresh experiment id derived from the models dir.
    experiment_id = resolve_experiment_id(args.models_dir)
    experiment_id_str = str(experiment_id)

    # Outer loop: one full train/extract cycle per random seed.
    with tqdm(seeds) as seed_bar:
        for seed in seed_bar:
            seed_bar.set_description('seed={}'.format(seed))
            seed_str = str(seed)
            models_dir = args.models_dir + os.sep + experiment_id_str + os.sep + 'seeds' + \
                os.sep + seed_str
            pathlib.Path(models_dir).mkdir(parents=True, exist_ok=True)

            dataset_parameters = DatasetParameters(
                name=args.dataset_name,
                test_size=args.test_size,
                dev_size=args.dev_size,
                random_state=seed,
                normalize=normalize,
                train_on_subset=args.train_on_subset
            )
            dataset_parameters.save(models_dir, experiment_id_str)

            dataset = DatasetLoader.load(dataset_parameters)

            trainer = Trainer(dataset)

            # Inner loop: one extracted (pruned) forest per requested size.
            with tqdm(args.extracted_forest_size) as extracted_forest_size_bar:
                for extracted_forest_size in extracted_forest_size_bar:
                    extracted_forest_size_bar.set_description('extracted_forest_size={}'.format(extracted_forest_size))
                    sub_models_dir = models_dir + os.sep + 'extracted_forest_size' + os.sep + str(extracted_forest_size)
                    pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)

                    model_parameters = ModelParameters(
                        forest_size=args.forest_size,
                        extracted_forest_size=extracted_forest_size,
                        seed=seed,
                        normalize=normalize
                    )
                    # Pass the id as str for consistency with
                    # dataset_parameters.save above (original passed the int).
                    model_parameters.save(sub_models_dir, experiment_id_str)

                    model = ModelFactory.build(dataset.task, model_parameters)

                    trainer.train(model, sub_models_dir)