diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py index 09562958a0c1a58c8a82dde4cbbcd538f02c54d0..9fea5053f83a774026ac69c5ed7da47a6a36a296 100644 --- a/code/bolsonaro/trainer.py +++ b/code/bolsonaro/trainer.py @@ -42,10 +42,14 @@ class Trainer(object): def base_score_metric_name(self): return self._base_score_metric_name - def init(self, model): + def init(self, model, subsets_used='train,dev'): if type(model) in [RandomForestRegressor, RandomForestClassifier]: - self._X_forest = self._dataset.X_train - self._y_forest = self._dataset.y_train + if subsets_used == 'train,dev': + self._X_forest = self._dataset.X_train + self._y_forest = self._dataset.y_train + else: + self._X_forest = np.concatenate([self._dataset.X_train, self._dataset.X_dev]) + self._y_forest = np.concatenate([self._dataset.y_train, self._dataset.y_dev]) self._logger.debug('Fitting the forest on train subset') elif model.models_parameters.subsets_used == 'train,dev': self._X_forest = self._dataset.X_train diff --git a/code/compute_results.py b/code/compute_results.py index c8c90653cdd493b5cad5bfe1cefd4c58747bdb04..473044d2fd05deeeeb86d927abd3a13ee35bd5de 100644 --- a/code/compute_results.py +++ b/code/compute_results.py @@ -328,7 +328,7 @@ if __name__ == "__main__": ylabel=experiments_score_metric, title='Loss values of {}\nusing different training subsets'.format(args.dataset_name))""" elif args.stage == 4: - if len(args.experiment_ids) != 5: + if len(args.experiment_ids) != 3: raise ValueError('In the case of stage 4, the number of specified experiment ids must be 3.') # Retreive the extracted forest sizes number used in order to have a base forest axis as long as necessary @@ -351,7 +351,7 @@ if __name__ == "__main__": omp_with_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes( args.models_dir, args.results_dir, args.experiment_ids[2]) - # base_with_params + """# base_with_params logger.info('Loading base_with_params experiment scores 2...') _, _, base_with_params_test_scores_2, \ _ = \ @@ -361,7 +361,7 @@ if __name__ == "__main__": logger.info('Loading random_with_params experiment scores 2...') _, _, random_with_params_test_scores_2, \ _, _ = \ - extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, args.experiment_ids[4]) + extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, args.experiment_ids[4])""" # Sanity check on the metrics retreived if not (base_with_params_experiment_score_metric == random_with_params_experiment_score_metric @@ -374,10 +374,8 @@ if __name__ == "__main__": Plotter.plot_stage2_losses( file_path=output_path + os.sep + 'losses.png', - all_experiment_scores=[base_with_params_test_scores, base_with_params_test_scores_2, random_with_params_test_scores, - random_with_params_test_scores_2, - omp_with_params_test_scores], - all_labels=['base_train-dev', 'base', 'random_train-dev', 'random', 'omp'], + all_experiment_scores=[base_with_params_test_scores, random_with_params_test_scores, omp_with_params_test_scores], + all_labels=['base', 'random', 'omp'], x_value=with_params_extracted_forest_sizes, xlabel='Number of trees extracted', ylabel=experiments_score_metric, diff --git a/code/train.py b/code/train.py index 0931191ac4872a45ec615d65420886128540f1a1..1131f2bf390f545385654ae59aea65a54e3f9977 100644 --- a/code/train.py +++ b/code/train.py @@ -73,7 +73,7 @@ def process_job(seed, parameters, experiment_id, hyperparameters): model = ModelFactory.build(dataset.task, model_parameters) - trainer.init(model) + trainer.init(model, subsets_used=parameters['subsets_used']) trainer.train(model) trainer.compute_results(model, sub_models_dir) else: @@ -95,7 +95,7 @@ def process_job(seed, parameters, experiment_id, hyperparameters): model = ModelFactory.build(dataset.task, model_parameters) - trainer.init(model) + trainer.init(model, subsets_used=parameters['subsets_used']) trainer.train(model) trainer.compute_results(model, sub_models_dir) logger.info('Training done') diff --git a/results/boston/stage4/losses.png b/results/boston/stage4/losses.png index ab2f45d2f8744176f158dd9d070c9dba7d8100a6..c5d57ce0b386934e9bd2cadcce5b44f8fb8a40d4 100644 Binary files a/results/boston/stage4/losses.png and b/results/boston/stage4/losses.png differ diff --git a/results/breast_cancer/stage4/losses.png b/results/breast_cancer/stage4/losses.png index 44268f88352ea22c7a326dcffc71e71070ac7274..0cfd248266a4eb867bbe8d1815c36fa44521374d 100644 Binary files a/results/breast_cancer/stage4/losses.png and b/results/breast_cancer/stage4/losses.png differ diff --git a/results/california_housing/stage4/losses.png b/results/california_housing/stage4/losses.png index fc285aad17f21b80886457d5915fdf723309484e..d8e6b8087bd8abea41d556853d115a5ce0ee7b35 100644 Binary files a/results/california_housing/stage4/losses.png and b/results/california_housing/stage4/losses.png differ diff --git a/results/diabetes/stage4/losses.png b/results/diabetes/stage4/losses.png index f2fbccddb99359a9b12250ce339765f4e4946131..ae057e660cc526be5002813ce32b8998f5746f6b 100644 Binary files a/results/diabetes/stage4/losses.png and b/results/diabetes/stage4/losses.png differ diff --git a/results/diamonds/stage4/losses.png b/results/diamonds/stage4/losses.png index d6da7b1cb6f9d17de1dff72af5dc11f124a310c2..00f0fb9373250dea67c59587fbc52cf7774e0378 100644 Binary files a/results/diamonds/stage4/losses.png and b/results/diamonds/stage4/losses.png differ diff --git a/results/digits/stage4/losses.png b/results/digits/stage4/losses.png index cabee4be51b813c38a66a1d478f3d3fa7f2a3181..1b28b2f47d19d5a82fb21f186a5e8212ce623c51 100644 Binary files a/results/digits/stage4/losses.png and b/results/digits/stage4/losses.png differ diff --git a/results/iris/stage4/losses.png b/results/iris/stage4/losses.png index f9a57ccb6de7c466babbe1c78ec29a429cb50a9d..cffa172cc4d8af8b53874d9030ad71806638577c 100644 Binary files a/results/iris/stage4/losses.png and b/results/iris/stage4/losses.png differ diff --git a/results/lfw_pairs/stage4/losses.png b/results/lfw_pairs/stage4/losses.png new file mode 100644 index 0000000000000000000000000000000000000000..ddb71771aa7b2058d0020ad31a5431175eade83f Binary files /dev/null and b/results/lfw_pairs/stage4/losses.png differ diff --git a/results/olivetti_faces/stage4/losses.png b/results/olivetti_faces/stage4/losses.png index 0dfb13b327925ab329e1e3b014d65a880fe596d0..862840acf9243e92052825fb35522e5bc3f7a3a4 100644 Binary files a/results/olivetti_faces/stage4/losses.png and b/results/olivetti_faces/stage4/losses.png differ diff --git a/results/wine/stage4/losses.png b/results/wine/stage4/losses.png index 2f71e9a25f311da1014cb2148c6d88cbb616fa64..d31ee5233c47047047499cfe15caf876d8f9fe06 100644 Binary files a/results/wine/stage4/losses.png and b/results/wine/stage4/losses.png differ