from bolsonaro.data.dataset_parameters import DatasetParameters
from bolsonaro.data.dataset_loader import DatasetLoader
from bolsonaro.models.model_factory import ModelFactory
from bolsonaro.models.model_parameters import ModelParameters
from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor
from bolsonaro.trainer import Trainer
from bolsonaro.utils import resolve_experiment_id, tqdm_joblib
from bolsonaro import LOG_PATH
from bolsonaro.error_handling.logger_factory import LoggerFactory

from dotenv import find_dotenv, load_dotenv
import argparse
import copy
import json
import pathlib
import random
import os
from joblib import Parallel, delayed
import threading
from tqdm import tqdm
import numpy as np
import shutil


def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verbose):
    """
    Experiment function.

    Will be used as base function for worker in multithreaded application.

    :param seed:
    :param parameters:
    :param experiment_id:
    :return:
    """
    logger = LoggerFactory.create(LOG_PATH, 'training_seed{}_ti{}'.format(
        seed, threading.get_ident()))

    seed_str = str(seed)
    experiment_id_str = str(experiment_id)
    models_dir = parameters['models_dir'] + os.sep + experiment_id_str + os.sep + 'seeds' + \
        os.sep + seed_str
    pathlib.Path(models_dir).mkdir(parents=True, exist_ok=True)

    dataset_parameters = DatasetParameters(
        name=parameters['dataset_name'],
        test_size=parameters['test_size'],
        dev_size=parameters['dev_size'],
        random_state=seed,
        dataset_normalizer=parameters['dataset_normalizer']
    )
    dataset_parameters.save(models_dir, experiment_id_str)
    dataset = DatasetLoader.load(dataset_parameters)

    trainer = Trainer(dataset)

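    # The 'ensemble' extraction strategy selects trees from a library of
    # pre-trained base learners, so generate that library once per seed;
    # it is passed to every model built below.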
    if parameters['extraction_strategy'] == 'ensemble':
        library = EnsembleSelectionForestRegressor.generate_library(dataset.X_train, dataset.y_train, random_state=seed)
    else:
        library = None

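    # The 'random' extraction strategy does not depend on the extracted
    # forest size, so the estimator is fitted once per seed here and a
    # deep copy of it is used for every size in extracted_forest_size_job.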
    if parameters['extraction_strategy'] == 'random':
        pretrained_model_parameters = ModelParameters(
            extracted_forest_size=parameters['forest_size'],
            normalize_D=parameters['normalize_D'],
            subsets_used=parameters['subsets_used'],
            normalize_weights=parameters['normalize_weights'],
            seed=seed,
            hyperparameters=hyperparameters,
            extraction_strategy=parameters['extraction_strategy']
        )
        pretrained_estimator = ModelFactory.build(dataset.task, pretrained_model_parameters, library=library)
        pretrained_trainer = Trainer(dataset)
        pretrained_trainer.init(pretrained_estimator, subsets_used=parameters['subsets_used'])
        pretrained_estimator.fit(
            X=pretrained_trainer._X_forest,
            y=pretrained_trainer._y_forest
        )
    else:
        pretrained_estimator = None
        pretrained_model_parameters = None

    if parameters['extraction_strategy'] != 'none':
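        # Launch one parallel job per extracted forest size; tqdm_joblib
        # patches joblib so each completed job advances the progress bar.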
        with tqdm_joblib(tqdm(total=len(parameters['extracted_forest_size']), disable=not verbose)) as extracted_forest_size_job_pb:
            Parallel(n_jobs=-1)(delayed(extracted_forest_size_job)(extracted_forest_size_job_pb, parameters['extracted_forest_size'][i],
                models_dir, seed, parameters, dataset, hyperparameters, experiment_id, trainer, library,
                pretrained_estimator=pretrained_estimator, pretrained_model_parameters=pretrained_model_parameters)
                for i in range(len(parameters['extracted_forest_size'])))
    else:
        forest_size = hyperparameters['n_estimators']
        logger.info('Base forest training with fixed forest size of {}'.format(forest_size))
        sub_models_dir = models_dir + os.sep + 'forest_size' + os.sep + str(forest_size)

        # Check if the result file already exists
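        # (a non-empty .pickle file in the output directory counts as an
        # already computed result)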
        already_exists = False
        if os.path.isdir(sub_models_dir):
            sub_models_dir_files = os.listdir(sub_models_dir)
            for file_name in sub_models_dir_files:
                if '.pickle' != os.path.splitext(file_name)[1]:
                    continue
                else:
                    already_exists = os.path.getsize(os.path.join(sub_models_dir, file_name)) > 0
                    break
        if already_exists:
            logger.info('Base forest result already exists. Skipping...')
        else:
            pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
            model_parameters = ModelParameters(
                extracted_forest_size=forest_size,
                normalize_D=parameters['normalize_D'],
                subsets_used=parameters['subsets_used'],
                normalize_weights=parameters['normalize_weights'],
                seed=seed,
                hyperparameters=hyperparameters,
                extraction_strategy=parameters['extraction_strategy']
            )
            model_parameters.save(sub_models_dir, experiment_id)

            model = ModelFactory.build(dataset.task, model_parameters, library=library)

            trainer.init(model, subsets_used=parameters['subsets_used'])
            trainer.train(model)
            trainer.compute_results(model, sub_models_dir)
    logger.info(f'Training done for seed {seed_str}')
    seed_job_pb.update(1)

def extracted_forest_size_job(extracted_forest_size_job_pb, extracted_forest_size, models_dir,
    seed, parameters, dataset, hyperparameters, experiment_id, trainer, library,
    pretrained_estimator=None, pretrained_model_parameters=None):
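    """
    Train and evaluate a single model for one extracted forest size.

    Runs as a joblib worker: if a non-empty result file already exists the
    job is skipped; otherwise the estimator is built (or copied from the
    pretrained one), trained, and its results written to sub_models_dir.
    """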

    logger = LoggerFactory.create(LOG_PATH, 'training_seed{}_extracted_forest_size{}_ti{}'.format(
        seed, extracted_forest_size, threading.get_ident()))
    logger.info('extracted_forest_size={}'.format(extracted_forest_size))

    sub_models_dir = models_dir + os.sep + 'extracted_forest_sizes' + os.sep + str(extracted_forest_size)

    # Check if the result file already exists
    already_exists = False
    if os.path.isdir(sub_models_dir):
        sub_models_dir_files = os.listdir(sub_models_dir)
        for file_name in sub_models_dir_files:
            if '.pickle' != os.path.splitext(file_name)[1]:
                continue
            else:
                already_exists = os.path.getsize(os.path.join(sub_models_dir, file_name)) > 0
                break
    if already_exists:
        logger.info(f'Extracted forest {extracted_forest_size} result already exists. Skipping...')
        return

    pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)

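    # Either build a fresh model for this size, or reuse a deep copy of the
    # estimator that was pre-trained once per seed (random strategy).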
    if not pretrained_estimator:
        model_parameters = ModelParameters(
            extracted_forest_size=extracted_forest_size,
            normalize_D=parameters['normalize_D'],
            subsets_used=parameters['subsets_used'],
            normalize_weights=parameters['normalize_weights'],
            seed=seed,
            hyperparameters=hyperparameters,
            extraction_strategy=parameters['extraction_strategy']
        )
        model_parameters.save(sub_models_dir, experiment_id)
        model = ModelFactory.build(dataset.task, model_parameters, library=library)
    else:
        model = copy.deepcopy(pretrained_estimator)
        pretrained_model_parameters.save(sub_models_dir, experiment_id)

    trainer.init(model, subsets_used=parameters['subsets_used'])
    trainer.train(model, extracted_forest_size=extracted_forest_size)
    trainer.compute_results(model, sub_models_dir)

"""
Command line examples for stage 1:
python code/train.py --dataset_name=california_housing --seeds 1 2 3 4 5 --extraction_strategy=none --save_experiment_configuration 1 none_with_params --extracted_forest_size_stop=0.05
python code/train.py --dataset_name=california_housing --seeds 1 2 3 4 5 --extraction_strategy=random --save_experiment_configuration 1 random_with_params --extracted_forest_size_stop=0.05
python code/train.py --dataset_name=california_housing --seeds 1 2 3 4 5 --save_experiment_configuration 1 omp_with_params --extracted_forest_size_stop=0.05
python code/train.py --dataset_name=california_housing --seeds 1 2 3 4 5 --extraction_strategy=none --skip_best_hyperparams --save_experiment_configuration 1 none_wo_params --extracted_forest_size_stop=0.05
python code/train.py --dataset_name=california_housing --seeds 1 2 3 4 5 --extraction_strategy=random --skip_best_hyperparams --save_experiment_configuration 1 random_wo_params --extracted_forest_size_stop=0.05
python code/train.py --dataset_name=california_housing --seeds 1 2 3 4 5 --skip_best_hyperparams --save_experiment_configuration 1 omp_wo_params --extracted_forest_size_stop=0.05
python code/compute_results.py --stage 1 --experiment_ids 1 2 3 4 5 6 --dataset_name=california_housing

Command line examples for stage 2:
python code/train.py --dataset_name=california_housing --seeds 1 2 3 4 5 --save_experiment_configuration 2 no_normalization --extracted_forest_size_stop=0.05
python code/train.py --dataset_name=california_housing --seeds 1 2 3 4 5 --save_experiment_configuration 2 normalize_D --normalize_D --extracted_forest_size_stop=0.05
python code/train.py --dataset_name=california_housing --seeds 1 2 3 4 5 --save_experiment_configuration 2 normalize_weights --normalize_weights --extracted_forest_size_stop=0.05
python code/train.py --dataset_name=california_housing --seeds 1 2 3 4 5 --save_experiment_configuration 2 normalize_D_and_weights --normalize_D --normalize_weights --extracted_forest_size_stop=0.05
python code/compute_results.py --stage 2 --experiment_ids 7 8 9 10 --dataset_name=california_housing

Command line examples for stage 3:
python code/train.py --dataset_name=california_housing --seeds 1 2 3 4 5 --save_experiment_configuration 3 train-dev_subset --extracted_forest_size_stop=0.05 --subsets_used train,dev
python code/train.py --dataset_name=california_housing --seeds 1 2 3 4 5 --save_experiment_configuration 3 train-dev_train-dev_subset --extracted_forest_size_stop=0.05 --subsets_used train+dev,train+dev
python code/train.py --dataset_name=california_housing --seeds 1 2 3 4 5 --save_experiment_configuration 3 train-train-dev_subset --extracted_forest_size_stop=0.05 --subsets_used train,train+dev
python code/compute_results.py --stage 3 --experiment_ids 11 12 13 --dataset_name=california_housing

Command line examples for stage 4:
python code/train.py --dataset_name=california_housing --seeds 1 2 3 4 5 --extraction_strategy=none --save_experiment_configuration 4 none_with_params --extracted_forest_size_stop=0.05
python code/train.py --dataset_name=california_housing --seeds 1 2 3 4 5 --extraction_strategy=random --save_experiment_configuration 4 random_with_params --extracted_forest_size_stop=0.05
python code/train.py --dataset_name=california_housing --seeds 1 2 3 4 5 --save_experiment_configuration 4 omp_with_params --extracted_forest_size_stop=0.05 --subsets_used train+dev,train+dev
python code/compute_results.py --stage 4 --experiment_ids 1 2 3 --dataset_name=california_housing
"""
if __name__ == "__main__":
    load_dotenv(find_dotenv('.env'))
    DEFAULT_EXPERIMENT_CONFIGURATION_PATH = 'experiments'
    # the models will be stored in a directory structure like: models/{experiment_id}/seeds/{seed_nb}/extracted_forest_sizes/{extracted_forest_size}
    DEFAULT_MODELS_DIR = os.environ['project_dir'] + os.sep + 'models'
    DEFAULT_VERBOSE = False
    DEFAULT_SKIP_BEST_HYPERPARAMS = False
    DEFAULT_JOB_NUMBER = -1
    DEFAULT_EXTRACTION_STRATEGY = 'omp'
    DEFAULT_OVERWRITE = False

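    # Range from which the seeds are drawn when --seeds is not specified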
    begin_random_seed_range = 1
    end_random_seed_range = 2000

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--experiment_id', nargs='?', type=int, default=None, help='Specify an experiment id. Remove already existing model with this specified experiment id.')
    parser.add_argument('--experiment_configuration', nargs='?', type=str, default=None, help='Specify an experiment configuration file name. Overrides all other parameters.')
    parser.add_argument('--experiment_configuration_path', nargs='?', type=str, default=DEFAULT_EXPERIMENT_CONFIGURATION_PATH, help='Specify the experiment configuration directory path.')
    parser.add_argument('--dataset_name', nargs='?', type=str, default=DatasetLoader.DEFAULT_DATASET_NAME, help='Specify the dataset. Regression: boston, diabetes, linnerud, california_housing. Classification: iris, digits, wine, breast_cancer, olivetti_faces, 20newsgroups, 20newsgroups_vectorized, lfw_people, lfw_pairs, covtype, rcv1, kddcup99.')
    parser.add_argument('--normalize_D', action='store_true', default=DatasetLoader.DEFAULT_NORMALIZE_D, help='Specify if the forest predictions should be normalized by dividing each prediction vector by its L2 norm.')
    parser.add_argument('--dataset_normalizer', nargs='?', type=str, default=DatasetLoader.DEFAULT_DATASET_NORMALIZER, help='Specify which dataset normalizer to use (either standard, minmax, robust or normalizer).')
    parser.add_argument('--forest_size', nargs='?', type=int, default=None, help='The number of trees of the random forest.')
    parser.add_argument('--extracted_forest_size_samples', nargs='?', type=int, default=DatasetLoader.DEFAULT_EXTRACTED_FOREST_SIZE_SAMPLES, help='The number of extracted forest sizes (proportional to the forest size) selected by OMP.')
    parser.add_argument('--extracted_forest_size_stop', nargs='?', type=float, default=DatasetLoader.DEFAULT_EXTRACTED_FOREST_SIZE_STOP, help='Specify the upper bound of the extracted forest sizes linspace.')
    parser.add_argument('--models_dir', nargs='?', type=str, default=DEFAULT_MODELS_DIR, help='The output directory of the trained models.')
    parser.add_argument('--dev_size', nargs='?', type=float, default=DatasetLoader.DEFAULT_DEV_SIZE, help='Dev subset ratio.')
    parser.add_argument('--test_size', nargs='?', type=float, default=DatasetLoader.DEFAULT_TEST_SIZE, help='Test subset ratio.')
    parser.add_argument('--random_seed_number', nargs='?', type=int, default=DatasetLoader.DEFAULT_RANDOM_SEED_NUMBER, help='Number of random seeds used.')
    parser.add_argument('--seeds', nargs='+', type=int, default=None, help='Specify a list of seeds instead of generating them randomly.')
    parser.add_argument('--subsets_used', nargs='?', type=str, default=DatasetLoader.DEFAULT_SUBSETS_USED, help='train,dev: forest on train, OMP on dev. train+dev,train+dev: both forest and OMP on train+dev. train,train+dev: forest on train, OMP on train+dev.')
    parser.add_argument('--normalize_weights', action='store_true', default=DatasetLoader.DEFAULT_NORMALIZE_WEIGHTS, help='Divide the predictions by the weights sum.')
    parser.add_argument('--verbose', action='store_true', default=DEFAULT_VERBOSE, help='Print tqdm progress bar.')
    parser.add_argument('--skip_best_hyperparams', action='store_true', default=DEFAULT_SKIP_BEST_HYPERPARAMS, help='Do not use the best hyperparameters if they exist.')
    parser.add_argument('--save_experiment_configuration', nargs='+', default=None, help='Save the experiment parameters specified in the command line in a file. Args: {{stage_num}} {{name}}')
    parser.add_argument('--job_number', nargs='?', type=int, default=DEFAULT_JOB_NUMBER, help='Specify the number of jobs used during the parallelisation across seeds.')
    parser.add_argument('--extraction_strategy', nargs='?', type=str, default=DEFAULT_EXTRACTION_STRATEGY, help='Specify the strategy to apply to extract the trees from the forest. Either omp, random, none, similarity, kmeans, ensemble.')
    parser.add_argument('--overwrite', action='store_true', default=DEFAULT_OVERWRITE, help='Overwrite any existing models of the specified experiment id.')
    args = parser.parse_args()

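    # A JSON experiment configuration, when given, replaces all
    # command-line arguments.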
    if args.experiment_configuration:
        with open(args.experiment_configuration_path + os.sep + \
            args.experiment_configuration + '.json', 'r') as input_file:
            parameters = json.load(input_file)
    else:
        parameters = args.__dict__

    if parameters['extraction_strategy'] not in ['omp', 'random', 'none', 'similarity', 'kmeans', 'ensemble']:
        raise ValueError('Specified extraction strategy {} is not supported.'.format(parameters['extraction_strategy']))

    pathlib.Path(parameters['models_dir']).mkdir(parents=True, exist_ok=True)

    logger = LoggerFactory.create(LOG_PATH, os.path.basename(__file__))

    hyperparameters_path = os.path.join('experiments', args.dataset_name, 'stage1', 'params.json')
    if os.path.exists(hyperparameters_path):
        logger.info("Hyperparameters found for this dataset at '{}'".format(hyperparameters_path))
        with open(hyperparameters_path, 'r') as file_hyperparameter:
            loaded_hyperparameters = json.load(file_hyperparameter)['best_parameters']
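            # When the best hyperparameters are skipped, only the tuned
            # forest size (n_estimators) is kept.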
            if args.skip_best_hyperparams:
                hyperparameters = {'n_estimators': loaded_hyperparameters['n_estimators']}
            else:
                hyperparameters = loaded_hyperparameters
    else:
        hyperparameters = {}

    """
    First case: no best hyperparameters are specified and no forest_size parameter
    specified in argument, so use the DEFAULT_FOREST_SIZE.
    Second case: no matter if hyperparameters are specified, the forest_size parameter
    will override it.
    Third implicit case: use the number of estimators found in specified hyperparameters.
    """
    if len(hyperparameters) == 0 and parameters['forest_size'] is None:
        hyperparameters['n_estimators'] = DatasetLoader.DEFAULT_FOREST_SIZE
    elif parameters['forest_size'] is not None:
        hyperparameters['n_estimators'] = parameters['forest_size']

    # The numbers of trees to extract from the forest (K)
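    # (evenly spaced fractions of n_estimators up to
    # extracted_forest_size_stop, rounded and deduplicated)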
    parameters['extracted_forest_size'] = np.unique(np.around(hyperparameters['n_estimators'] *
        np.linspace(0, parameters['extracted_forest_size_stop'],
        parameters['extracted_forest_size_samples'] + 1,
        endpoint=True)[1:]).astype(int)).tolist()

    if parameters['seeds'] is not None and parameters['random_seed_number'] > 1:
        logger.warning('seeds and random_seed_number parameters are both specified. Seeds will be used.')

    # Seeds are either provided as parameters or generated at random
    seeds = parameters['seeds'] if parameters['seeds'] is not None \
        else [random.randint(begin_random_seed_range, end_random_seed_range) \
        for i in range(parameters['random_seed_number'])]

    if args.experiment_id:
        experiment_id = args.experiment_id
        if args.overwrite:
            shutil.rmtree(os.path.join(parameters['models_dir'], str(experiment_id)), ignore_errors=True)
    else:
        # Resolve the next experiment id number (last id + 1)
        experiment_id = resolve_experiment_id(parameters['models_dir'])
    logger.info('Experiment id: {}'.format(experiment_id))

    """
    If the experiment configuration isn't coming from
    an already existing file, save it to a json file to
    keep track of it (either at a specified path, or in the 'unnamed' directory).
    """
    if args.experiment_configuration is None:
        if args.save_experiment_configuration:
            if len(args.save_experiment_configuration) != 2:
                raise ValueError('save_experiment_configuration must have two parameters.')
            elif int(args.save_experiment_configuration[0]) not in list(range(1, 6)):
                raise ValueError('save_experiment_configuration first parameter must be a supported stage id (i.e. [1, 5]).')
            output_experiment_stage_path = os.path.join(args.experiment_configuration_path,
                args.dataset_name, 'stage' + args.save_experiment_configuration[0])
            pathlib.Path(output_experiment_stage_path).mkdir(parents=True, exist_ok=True)
            output_experiment_configuration_path = os.path.join(output_experiment_stage_path,
                args.save_experiment_configuration[1] + '.json')
        else:
            pathlib.Path(os.path.join(args.experiment_configuration_path, 'unnamed')).mkdir(parents=True, exist_ok=True)
            output_experiment_configuration_path = os.path.join(
                args.experiment_configuration_path, 'unnamed', 'unnamed_{}.json'.format(
                experiment_id))
        with open(output_experiment_configuration_path, 'w') as output_file:
            json.dump(
                parameters,
                output_file,
                indent=4
            )

    # Run as many jobs as there are seeds
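    # (at most job_number workers run concurrently)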
    with tqdm_joblib(tqdm(total=len(seeds), disable=not args.verbose)) as seed_job_pb:
        Parallel(n_jobs=args.job_number)(delayed(seed_job)(seed_job_pb, seeds[i],
            parameters, experiment_id, hyperparameters, args.verbose) for i in range(len(seeds)))