Skip to content
Snippets Groups Projects
Commit 2df5560f authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Update plot in compute results to have two axis instead of two plots.

parent 3a2ec5cb
No related branches found
No related tags found
1 merge request!9Resolve "Experiment pipeline"
......@@ -57,10 +57,10 @@ class Plotter(object):
ax.plot(x_value, mean, c=color_mean, label=label)
@staticmethod
def plot_losses(file_path, all_experiment_scores, x_value, xlabel, ylabel, all_labels, title):
fig, ax = plt.subplots()
def plot_losses(file_path, all_experiment_scores_1, all_experiment_scores_2, x_value, xlabel, ylabel, all_labels, title):
fig, axes = plt.subplots(nrows=1, ncols=2)
n = len(all_experiment_scores)
n = len(len(all_experiment_scores_1))
"""
Get as many different colors from the specified cmap (here nipy_spectral)
......@@ -68,24 +68,25 @@ class Plotter(object):
"""
colors = Plotter.get_colors_from_cmap(n)
# For each curve to plot
for i in range(n):
# Retreive the scores in a list for each seed
experiment_scores = list(all_experiment_scores[i].values())
# Compute the mean and the std for the CI
mean_experiment_scores = np.average(experiment_scores, axis=0)
std_experiment_scores = np.std(experiment_scores, axis=0)
# Plot the score curve with the CI
Plotter.plot_mean_and_CI(
ax=ax,
mean=mean_experiment_scores,
lb=mean_experiment_scores + std_experiment_scores,
ub=mean_experiment_scores - std_experiment_scores,
x_value=x_value,
color_mean=colors[i],
facecolor=colors[i],
label=all_labels[i]
)
for j, all_experiment_scores in enumerate([all_experiment_scores_1, all_experiment_scores_2]):
# For each curve to plot
for i in range(n):
# Retreive the scores in a list for each seed
experiment_scores = list(all_experiment_scores[i].values())
# Compute the mean and the std for the CI
mean_experiment_scores = np.average(experiment_scores, axis=0)
std_experiment_scores = np.std(experiment_scores, axis=0)
# Plot the score curve with the CI
Plotter.plot_mean_and_CI(
ax=axes[j],
mean=mean_experiment_scores,
lb=mean_experiment_scores + std_experiment_scores,
ub=mean_experiment_scores - std_experiment_scores,
x_value=x_value,
color_mean=colors[i],
facecolor=colors[i],
label=all_labels[i]
)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
......
......@@ -175,10 +175,13 @@ if __name__ == "__main__":
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)
Plotter.plot_losses(
file_path=output_path + os.sep + 'losses_with_params.png',
all_experiment_scores=[base_with_params_train_scores, base_with_params_dev_scores, base_with_params_test_scores,
file_path=output_path + os.sep + 'losses.png',
all_experiment_scores_1=[base_with_params_train_scores, base_with_params_dev_scores, base_with_params_test_scores,
random_with_params_train_scores, random_with_params_dev_scores, random_with_params_test_scores,
omp_with_params_train_scores, omp_with_params_dev_scores, omp_with_params_test_scores],
all_experiment_scores_2=[base_wo_params_train_scores, base_wo_params_dev_scores, base_wo_params_test_scores,
random_wo_params_train_scores, random_wo_params_dev_scores, random_wo_params_test_scores,
omp_wo_params_train_scores, omp_wo_params_dev_scores, omp_wo_params_test_scores],
x_value=with_params_extracted_forest_sizes,
xlabel='Number of trees extracted',
ylabel='MSE', # TODO: hardcoded
......@@ -187,19 +190,6 @@ if __name__ == "__main__":
'omp_with_params_train', 'omp_with_params_dev', 'omp_with_params_test'],
title='Loss values of {} using the best hyperparams'.format(args.dataset_name)
)
Plotter.plot_losses(
file_path=output_path + os.sep + 'losses_wo_params.png',
all_experiment_scores=[base_wo_params_train_scores, base_wo_params_dev_scores, base_wo_params_test_scores,
random_wo_params_train_scores, random_wo_params_dev_scores, random_wo_params_test_scores,
omp_wo_params_train_scores, omp_wo_params_dev_scores, omp_wo_params_test_scores],
x_value=wo_params_extracted_forest_sizes,
xlabel='Number of trees extracted',
ylabel='MSE', # TODO: hardcoded
all_labels=['base_wo_params_train', 'base_wo_params_dev', 'base_wo_params_test',
'random_wo_params_train', 'random_wo_params_dev', 'random_wo_params_test',
'omp_wo_params_train', 'omp_wo_params_dev', 'omp_wo_params_test'],
title='Loss values of {} without using the best hyperparams'.format(args.dataset_name)
)
else:
raise ValueError('This stage number is not supported yet, but it will be!')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment