Skip to content
Snippets Groups Projects
Commit 994ff6ba authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Update fix of normalize_D in the case of wo weights. Add density plot in...

Update fix of normalize_D in the case of wo weights. Add density plot in compute results for all stages and add some plots in results.
parent a1a7f767
No related branches found
No related tags found
No related merge requests found
Showing
with 50 additions and 18 deletions
......@@ -123,9 +123,7 @@ class SingleOmpForest(OmpForest):
forest_predictions = self._base_estimator_predictions(X)
if self._models_parameters.normalize_D:
forest_predictions = forest_predictions.T
forest_predictions /= self._forest_norms
forest_predictions = forest_predictions.T
return self._make_omp_weighted_prediction(forest_predictions, self._omp, self._models_parameters.normalize_weights)
......
......@@ -119,9 +119,7 @@ class OmpForestMulticlassClassifier(OmpForest):
forest_predictions = np.array([tree.predict_proba(X) for tree in self._base_forest_estimator.estimators_]).T
if self._models_parameters.normalize_D:
forest_predictions = forest_predictions.T
forest_predictions /= self._forest_norms
forest_predictions = forest_predictions.T
label_names = []
preds = []
......@@ -149,7 +147,9 @@ class OmpForestMulticlassClassifier(OmpForest):
forest_predictions = np.array([tree.predict_proba(X) for tree in self._base_forest_estimator.estimators_]).T
if self._models_parameters.normalize_D:
forest_predictions = forest_predictions.T
forest_predictions /= self._forest_norms
forest_predictions = forest_predictions.T
label_names = []
preds = []
......
......@@ -156,6 +156,7 @@ if __name__ == "__main__":
DEFAULT_RESULTS_DIR = os.environ["project_dir"] + os.sep + 'results'
DEFAULT_MODELS_DIR = os.environ["project_dir"] + os.sep + 'models'
DEFAULT_PLOT_WEIGHT_DENSITY = False
DEFAULT_WO_LOSS_PLOTS = False
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--stage', nargs='?', type=int, required=True, help='Specify the stage number among [1, 5].')
......@@ -168,6 +169,7 @@ if __name__ == "__main__":
parser.add_argument('--results_dir', nargs='?', type=str, default=DEFAULT_RESULTS_DIR, help='The output directory of the results.')
parser.add_argument('--models_dir', nargs='?', type=str, default=DEFAULT_MODELS_DIR, help='The output directory of the trained models.')
parser.add_argument('--plot_weight_density', action='store_true', default=DEFAULT_PLOT_WEIGHT_DENSITY, help='Plot the weight density. Only working for regressor models for now.')
parser.add_argument('--wo_loss_plots', action='store_true', default=DEFAULT_WO_LOSS_PLOTS, help='Do not compute the loss plots.')
args = parser.parse_args()
if args.stage not in list(range(1, 6)):
......@@ -181,7 +183,7 @@ if __name__ == "__main__":
# Create recursively the results dir tree
pathlib.Path(args.results_dir).mkdir(parents=True, exist_ok=True)
if args.stage == 1:
if args.stage == 1 and not args.wo_loss_plots:
if len(args.experiment_ids) != 6:
raise ValueError('In the case of stage 1, the number of specified experiment ids must be 6.')
......@@ -221,8 +223,8 @@ if __name__ == "__main__":
wo_params_extracted_forest_sizes, random_wo_params_experiment_score_metric = \
extract_scores_across_seeds_and_extracted_forest_sizes(
args.models_dir, args.results_dir, int(args.experiment_ids[4]))
# base_wo_params
logger.info('Loading base_wo_params experiment scores...')
# omp_wo_params
logger.info('Loading omp_wo_params experiment scores...')
omp_wo_params_train_scores, omp_wo_params_dev_scores, omp_wo_params_test_scores, _, \
omp_wo_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes(
args.models_dir, args.results_dir, int(args.experiment_ids[5]))
......@@ -262,7 +264,7 @@ if __name__ == "__main__":
ylabel=experiments_score_metric,
title='Loss values of {}\nusing best and default hyperparameters'.format(args.dataset_name)
)
elif args.stage == 2:
elif args.stage == 2 and not args.wo_loss_plots:
if len(args.experiment_ids) != 4:
raise ValueError('In the case of stage 2, the number of specified experiment ids must be 4.')
......@@ -308,7 +310,7 @@ if __name__ == "__main__":
xlabel='Number of trees extracted',
ylabel=experiments_score_metric,
title='Loss values of {}\nusing different normalizations'.format(args.dataset_name))
elif args.stage == 3:
elif args.stage == 3 and not args.wo_loss_plots:
if len(args.experiment_ids) != 3:
raise ValueError('In the case of stage 3, the number of specified experiment ids must be 3.')
......@@ -365,7 +367,7 @@ if __name__ == "__main__":
xlabel='Number of trees extracted',
ylabel=experiments_score_metric,
title='Loss values of {}\nusing different training subsets'.format(args.dataset_name))"""
elif args.stage == 4:
elif args.stage == 4 and not args.wo_loss_plots:
if len(args.experiment_ids) != 3:
raise ValueError('In the case of stage 4, the number of specified experiment ids must be 3.')
......@@ -427,11 +429,7 @@ if __name__ == "__main__":
xlabel='Number of trees extracted',
ylabel=experiments_score_metric,
title='Loss values of {}\nusing best params of previous stages'.format(args.dataset_name))
experiment_weights = extract_weights_across_seeds(args.models_dir, args.results_dir, args.experiment_ids[2])
Plotter.weight_density(experiment_weights, os.path.join(output_path, 'weight_density.png'))
elif args.stage == 5:
elif args.stage == 5 and not args.wo_loss_plots:
# Retreive the extracted forest sizes number used in order to have a base forest axis as long as necessary
extracted_forest_sizes_number = retreive_extracted_forest_sizes_number(args.models_dir, int(args.experiment_ids[1]))
all_labels = list()
......@@ -475,8 +473,9 @@ if __name__ == "__main__":
continue
logger.info(f'Loading {label} experiment scores...')
current_experiment_id = int(args.experiment_ids[i].split('=')[1])
_, _, current_test_scores, _, _ = extract_scores_across_seeds_and_extracted_forest_sizes(
args.models_dir, args.results_dir, int(args.experiment_ids[i].split('=')[1]))
args.models_dir, args.results_dir, current_experiment_id)
all_labels.append(label)
all_scores.append(current_test_scores)
......@@ -491,7 +490,42 @@ if __name__ == "__main__":
xlabel='Number of trees extracted',
ylabel=base_with_params_experiment_score_metric,
title='Loss values of {}\nusing best params of previous stages'.format(args.dataset_name))
if args.plot_weight_density:
root_output_path = os.path.join(args.results_dir, args.dataset_name, f'stage{args.stage}')
if args.stage == 1:
omp_experiment_ids = [('omp_with_params', args.experiment_ids[2]), ('omp_wo_params', args.experiment_ids[2])]
elif args.stage == 2:
omp_experiment_ids = [('no_normalization', args.experiment_ids[0]),
('normalize_D', args.experiment_ids[1]),
('normalize_weights', args.experiment_ids[2]),
('normalize_D_and_weights', args.experiment_ids[3])]
elif args.stage == 3:
omp_experiment_ids = [('train-dev_subset', args.experiment_ids[0]),
('train-dev_train-dev_subset', args.experiment_ids[1]),
('train-train-dev_subset', args.experiment_ids[2])]
elif args.stage == 4:
omp_experiment_ids = [('omp_with_params', args.experiment_ids[2])]
elif args.stage == 5:
omp_experiment_ids = [('omp_with_params', args.experiment_ids[2])]
for i in range(3, len(args.experiment_ids)):
if 'kmeans' in args.experiment_ids[i]:
label = 'kmeans'
elif 'similarity' in args.experiment_ids[i]:
label = 'similarity'
elif 'ensemble' in args.experiment_ids[i]:
label = 'ensemble'
else:
raise ValueError('This stage number is not supported yet, but it will be!')
logger.error('Invalid value encountered')
continue
current_experiment_id = int(args.experiment_ids[i].split('=')[1])
omp_experiment_ids.append((label, current_experiment_id))
for (experiment_label, experiment_id) in omp_experiment_ids:
logger.info(f'Computing weight density plot for experiment {experiment_label}...')
experiment_weights = extract_weights_across_seeds(args.models_dir, args.results_dir, experiment_id)
Plotter.weight_density(experiment_weights, os.path.join(root_output_path, f'weight_density_{experiment_label}.png'))
logger.info('Done.')
results/boston/stage1/weight_density_omp_with_params.png

94.6 KiB

results/boston/stage1/weight_density_omp_wo_params.png

94.6 KiB

results/boston/stage3/weight_density_train-dev_subset.png

31.3 KiB

results/boston/stage3/weight_density_train-dev_train-dev_subset.png

31.2 KiB

results/boston/stage3/weight_density_train-train-dev_subset.png

129 KiB

results/boston/stage4/weight_density_omp_with_params.png

81.3 KiB

results/breast_cancer/stage4/weight_density_omp_with_params.png

39.3 KiB

results/california_housing/stage4/weight_density_omp_with_params.png

378 KiB

results/diabetes/stage1/weight_density_omp_with_params.png

75.8 KiB

results/diabetes/stage1/weight_density_omp_wo_params.png

75.8 KiB

results/diabetes/stage3/weight_density_train-dev_subset.png

84.9 KiB

results/diabetes/stage3/weight_density_train-dev_train-dev_subset.png

332 KiB

results/diabetes/stage3/weight_density_train-train-dev_subset.png

286 KiB

results/diabetes/stage4/weight_density_omp_with_params.png

332 KiB

results/diamonds/stage3/weight_density_train-dev_subset.png

470 KiB

results/diamonds/stage3/weight_density_train-dev_train-dev_subset.png

304 KiB

results/diamonds/stage3/weight_density_train-train-dev_subset.png

370 KiB

0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment