diff --git a/code/bolsonaro/visualization/plotter.py b/code/bolsonaro/visualization/plotter.py index e82743db1386b1aa7a51d46fbc29818794ce64cb..586156c10d0ccc8ee4dc53c91c6e80c6477ae4e6 100644 --- a/code/bolsonaro/visualization/plotter.py +++ b/code/bolsonaro/visualization/plotter.py @@ -96,7 +96,7 @@ class Plotter(object): axes[0].set_xlabel(xlabel) axes[1].set_xlabel(xlabel) axes[0].set_ylabel(ylabel) - plt.suptitle(title) + axes[1].set_title(title) handles, labels = axes[0].get_legend_handles_labels() legend = axes[0].legend(handles, labels, loc='upper center', bbox_to_anchor=(1.1, -0.15)) fig.savefig(file_path, dpi=fig.dpi, bbox_extra_artists=(legend,), bbox_inches='tight') diff --git a/code/compute_results.py b/code/compute_results.py index 2db767586436f1d7e0fe44209fab953e80da91d6..fd04b4e10acb63334e21c42bc876977dec1bf7a6 100644 --- a/code/compute_results.py +++ b/code/compute_results.py @@ -192,7 +192,7 @@ if __name__ == "__main__": x_value=with_params_extracted_forest_sizes, xlabel='Number of trees extracted', ylabel='MSE', # TODO: hardcoded - title='Loss values of {} using best and default hyperparameters'.format(args.dataset_name) + title='Loss values of {}\nusing best and default hyperparameters'.format(args.dataset_name) ) else: raise ValueError('This stage number is not supported yet, but it will be!') diff --git a/results/california_housing/stage1/losses.png b/results/california_housing/stage1/losses.png index 902c5b74488875024b925388edac21bd749c10b0..670e0881812fc00539637cec6068b7300f871e3c 100644 Binary files a/results/california_housing/stage1/losses.png and b/results/california_housing/stage1/losses.png differ