Skip to content
Snippets Groups Projects
Commit 2df5560f authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Update plot in compute results to have two axis instead of two plots.

parent 3a2ec5cb
No related branches found
No related tags found
1 merge request!9Resolve "Experiment pipeline"
...@@ -57,10 +57,10 @@ class Plotter(object): ...@@ -57,10 +57,10 @@ class Plotter(object):
ax.plot(x_value, mean, c=color_mean, label=label) ax.plot(x_value, mean, c=color_mean, label=label)
@staticmethod @staticmethod
def plot_losses(file_path, all_experiment_scores, x_value, xlabel, ylabel, all_labels, title): def plot_losses(file_path, all_experiment_scores_1, all_experiment_scores_2, x_value, xlabel, ylabel, all_labels, title):
fig, ax = plt.subplots() fig, axes = plt.subplots(nrows=1, ncols=2)
n = len(all_experiment_scores) n = len(len(all_experiment_scores_1))
""" """
Get as many different colors from the specified cmap (here nipy_spectral) Get as many different colors from the specified cmap (here nipy_spectral)
...@@ -68,6 +68,7 @@ class Plotter(object): ...@@ -68,6 +68,7 @@ class Plotter(object):
""" """
colors = Plotter.get_colors_from_cmap(n) colors = Plotter.get_colors_from_cmap(n)
for j, all_experiment_scores in enumerate([all_experiment_scores_1, all_experiment_scores_2]):
# For each curve to plot # For each curve to plot
for i in range(n): for i in range(n):
# Retreive the scores in a list for each seed # Retreive the scores in a list for each seed
...@@ -77,7 +78,7 @@ class Plotter(object): ...@@ -77,7 +78,7 @@ class Plotter(object):
std_experiment_scores = np.std(experiment_scores, axis=0) std_experiment_scores = np.std(experiment_scores, axis=0)
# Plot the score curve with the CI # Plot the score curve with the CI
Plotter.plot_mean_and_CI( Plotter.plot_mean_and_CI(
ax=ax, ax=axes[j],
mean=mean_experiment_scores, mean=mean_experiment_scores,
lb=mean_experiment_scores + std_experiment_scores, lb=mean_experiment_scores + std_experiment_scores,
ub=mean_experiment_scores - std_experiment_scores, ub=mean_experiment_scores - std_experiment_scores,
......
...@@ -175,10 +175,13 @@ if __name__ == "__main__": ...@@ -175,10 +175,13 @@ if __name__ == "__main__":
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True) pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)
Plotter.plot_losses( Plotter.plot_losses(
file_path=output_path + os.sep + 'losses_with_params.png', file_path=output_path + os.sep + 'losses.png',
all_experiment_scores=[base_with_params_train_scores, base_with_params_dev_scores, base_with_params_test_scores, all_experiment_scores_1=[base_with_params_train_scores, base_with_params_dev_scores, base_with_params_test_scores,
random_with_params_train_scores, random_with_params_dev_scores, random_with_params_test_scores, random_with_params_train_scores, random_with_params_dev_scores, random_with_params_test_scores,
omp_with_params_train_scores, omp_with_params_dev_scores, omp_with_params_test_scores], omp_with_params_train_scores, omp_with_params_dev_scores, omp_with_params_test_scores],
all_experiment_scores_2=[base_wo_params_train_scores, base_wo_params_dev_scores, base_wo_params_test_scores,
random_wo_params_train_scores, random_wo_params_dev_scores, random_wo_params_test_scores,
omp_wo_params_train_scores, omp_wo_params_dev_scores, omp_wo_params_test_scores],
x_value=with_params_extracted_forest_sizes, x_value=with_params_extracted_forest_sizes,
xlabel='Number of trees extracted', xlabel='Number of trees extracted',
ylabel='MSE', # TODO: hardcoded ylabel='MSE', # TODO: hardcoded
...@@ -187,19 +190,6 @@ if __name__ == "__main__": ...@@ -187,19 +190,6 @@ if __name__ == "__main__":
'omp_with_params_train', 'omp_with_params_dev', 'omp_with_params_test'], 'omp_with_params_train', 'omp_with_params_dev', 'omp_with_params_test'],
title='Loss values of {} using the best hyperparams'.format(args.dataset_name) title='Loss values of {} using the best hyperparams'.format(args.dataset_name)
) )
Plotter.plot_losses(
file_path=output_path + os.sep + 'losses_wo_params.png',
all_experiment_scores=[base_wo_params_train_scores, base_wo_params_dev_scores, base_wo_params_test_scores,
random_wo_params_train_scores, random_wo_params_dev_scores, random_wo_params_test_scores,
omp_wo_params_train_scores, omp_wo_params_dev_scores, omp_wo_params_test_scores],
x_value=wo_params_extracted_forest_sizes,
xlabel='Number of trees extracted',
ylabel='MSE', # TODO: hardcoded
all_labels=['base_wo_params_train', 'base_wo_params_dev', 'base_wo_params_test',
'random_wo_params_train', 'random_wo_params_dev', 'random_wo_params_test',
'omp_wo_params_train', 'omp_wo_params_dev', 'omp_wo_params_test'],
title='Loss values of {} without using the best hyperparams'.format(args.dataset_name)
)
else: else:
raise ValueError('This stage number is not supported yet, but it will be!') raise ValueError('This stage number is not supported yet, but it will be!')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment