diff --git a/code/bolsonaro/visualization/plotter.py b/code/bolsonaro/visualization/plotter.py index 0d5706bc27cb0745fe065456231b7e3023707ac9..fd990dc7df19d2a86d554a8ee0514a36e37cdf53 100644 --- a/code/bolsonaro/visualization/plotter.py +++ b/code/bolsonaro/visualization/plotter.py @@ -57,10 +57,10 @@ class Plotter(object): ax.plot(x_value, mean, c=color_mean, label=label) @staticmethod - def plot_losses(file_path, all_experiment_scores, x_value, xlabel, ylabel, all_labels, title): - fig, ax = plt.subplots() + def plot_losses(file_path, all_experiment_scores_1, all_experiment_scores_2, x_value, xlabel, ylabel, all_labels, title): + fig, axes = plt.subplots(nrows=1, ncols=2) - n = len(all_experiment_scores) + n = len(len(all_experiment_scores_1)) """ Get as many different colors from the specified cmap (here nipy_spectral) @@ -68,24 +68,25 @@ class Plotter(object): """ colors = Plotter.get_colors_from_cmap(n) - # For each curve to plot - for i in range(n): - # Retreive the scores in a list for each seed - experiment_scores = list(all_experiment_scores[i].values()) - # Compute the mean and the std for the CI - mean_experiment_scores = np.average(experiment_scores, axis=0) - std_experiment_scores = np.std(experiment_scores, axis=0) - # Plot the score curve with the CI - Plotter.plot_mean_and_CI( - ax=ax, - mean=mean_experiment_scores, - lb=mean_experiment_scores + std_experiment_scores, - ub=mean_experiment_scores - std_experiment_scores, - x_value=x_value, - color_mean=colors[i], - facecolor=colors[i], - label=all_labels[i] - ) + for j, all_experiment_scores in enumerate([all_experiment_scores_1, all_experiment_scores_2]): + # For each curve to plot + for i in range(n): + # Retreive the scores in a list for each seed + experiment_scores = list(all_experiment_scores[i].values()) + # Compute the mean and the std for the CI + mean_experiment_scores = np.average(experiment_scores, axis=0) + std_experiment_scores = np.std(experiment_scores, axis=0) + # Plot the score curve with the CI + Plotter.plot_mean_and_CI( + ax=axes[j], + mean=mean_experiment_scores, + lb=mean_experiment_scores + std_experiment_scores, + ub=mean_experiment_scores - std_experiment_scores, + x_value=x_value, + color_mean=colors[i], + facecolor=colors[i], + label=all_labels[i] + ) plt.xlabel(xlabel) plt.ylabel(ylabel) diff --git a/code/compute_results.py b/code/compute_results.py index 40e137133f914d40f886a0c15ddcaa8f5c66b89c..7902b2b4c90f1aa7f36a40c5970687dadee7dc14 100644 --- a/code/compute_results.py +++ b/code/compute_results.py @@ -175,10 +175,13 @@ if __name__ == "__main__": pathlib.Path(output_path).mkdir(parents=True, exist_ok=True) Plotter.plot_losses( - file_path=output_path + os.sep + 'losses_with_params.png', - all_experiment_scores=[base_with_params_train_scores, base_with_params_dev_scores, base_with_params_test_scores, + file_path=output_path + os.sep + 'losses.png', + all_experiment_scores_1=[base_with_params_train_scores, base_with_params_dev_scores, base_with_params_test_scores, random_with_params_train_scores, random_with_params_dev_scores, random_with_params_test_scores, omp_with_params_train_scores, omp_with_params_dev_scores, omp_with_params_test_scores], + all_experiment_scores_2=[base_wo_params_train_scores, base_wo_params_dev_scores, base_wo_params_test_scores, + random_wo_params_train_scores, random_wo_params_dev_scores, random_wo_params_test_scores, + omp_wo_params_train_scores, omp_wo_params_dev_scores, omp_wo_params_test_scores], x_value=with_params_extracted_forest_sizes, xlabel='Number of trees extracted', ylabel='MSE', # TODO: hardcoded @@ -187,19 +190,6 @@ if __name__ == "__main__": 'omp_with_params_train', 'omp_with_params_dev', 'omp_with_params_test'], title='Loss values of {} using the best hyperparams'.format(args.dataset_name) ) - Plotter.plot_losses( - file_path=output_path + os.sep + 'losses_wo_params.png', - all_experiment_scores=[base_wo_params_train_scores, base_wo_params_dev_scores, base_wo_params_test_scores, - random_wo_params_train_scores, random_wo_params_dev_scores, random_wo_params_test_scores, - omp_wo_params_train_scores, omp_wo_params_dev_scores, omp_wo_params_test_scores], - x_value=wo_params_extracted_forest_sizes, - xlabel='Number of trees extracted', - ylabel='MSE', # TODO: hardcoded - all_labels=['base_wo_params_train', 'base_wo_params_dev', 'base_wo_params_test', - 'random_wo_params_train', 'random_wo_params_dev', 'random_wo_params_test', - 'omp_wo_params_train', 'omp_wo_params_dev', 'omp_wo_params_test'], - title='Loss values of {} without using the best hyperparams'.format(args.dataset_name) - ) else: raise ValueError('This stage number is not supported yet, but it will be!')