Skip to content
Snippets Groups Projects

Resolve "Experiment pipeline"

Merged Charly Lamothe requested to merge 12-experiment-pipeline into master
3 files
+ 36
25
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -57,7 +57,56 @@ class Plotter(object):
ax.plot(x_value, mean, c=color_mean, label=label)
@staticmethod
def plot_losses(file_path, all_experiment_scores, x_value, xlabel, ylabel, all_labels, title):
def plot_stage1_losses(file_path, all_experiment_scores_with_params,
all_experiment_scores_wo_params, x_value, xlabel, ylabel, all_labels, title):
fig, axes = plt.subplots(nrows=1, ncols=2, sharey=True)
n = len(all_experiment_scores_with_params)
if n != len(all_experiment_scores_wo_params):
raise ValueError('all_experiment_scores_with_params and all_experiment_scores_wo_params must have the same len to be compared.')
"""
Get as many different colors from the specified cmap (here nipy_spectral)
as there are curve to plot.
"""
colors = Plotter.get_colors_from_cmap(n)
for j, all_experiment_scores in enumerate([all_experiment_scores_with_params,
all_experiment_scores_wo_params]):
# For each curve to plot
for i in range(n):
# Retreive the scores in a list for each seed
experiment_scores = list(all_experiment_scores[i].values())
# Compute the mean and the std for the CI
mean_experiment_scores = np.average(experiment_scores, axis=0)
std_experiment_scores = np.std(experiment_scores, axis=0)
# Plot the score curve with the CI
Plotter.plot_mean_and_CI(
ax=axes[j],
mean=mean_experiment_scores,
lb=mean_experiment_scores + std_experiment_scores,
ub=mean_experiment_scores - std_experiment_scores,
x_value=x_value,
color_mean=colors[i],
facecolor=colors[i],
label=all_labels[i]
)
axes[0].set_xlabel(xlabel)
axes[1].set_xlabel(xlabel)
axes[0].set_ylabel(ylabel)
axes[1].set_title(title)
handles, labels = axes[0].get_legend_handles_labels()
legend = axes[0].legend(handles, labels, loc='upper center', bbox_to_anchor=(1.1, -0.15))
fig.savefig(file_path, dpi=fig.dpi, bbox_extra_artists=(legend,), bbox_inches='tight')
plt.close(fig)
@staticmethod
def plot_stage2_losses(file_path, all_experiment_scores, x_value,
xlabel, ylabel, all_labels, title):
fig, ax = plt.subplots()
n = len(all_experiment_scores)
@@ -91,7 +140,7 @@ class Plotter(object):
plt.ylabel(ylabel)
plt.title(title)
plt.legend(loc='upper right')
fig.savefig(file_path, dpi=fig.dpi)
fig.savefig(file_path, dpi=fig.dpi, bbox_inches='tight')
plt.close(fig)
@staticmethod
Loading