Skip to content
Snippets Groups Projects

Resolve "Experiment pipeline"

Merged Charly Lamothe requested to merge 12-experiment-pipeline into master
2 files
+ 27
36
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -57,10 +57,10 @@ class Plotter(object):
ax.plot(x_value, mean, c=color_mean, label=label)
@staticmethod
def plot_losses(file_path, all_experiment_scores, x_value, xlabel, ylabel, all_labels, title):
fig, ax = plt.subplots()
def plot_losses(file_path, all_experiment_scores_1, all_experiment_scores_2, x_value, xlabel, ylabel, all_labels, title):
fig, axes = plt.subplots(nrows=1, ncols=2)
n = len(all_experiment_scores)
n = len(len(all_experiment_scores_1))
"""
Get as many different colors from the specified cmap (here nipy_spectral)
@@ -68,24 +68,25 @@ class Plotter(object):
"""
colors = Plotter.get_colors_from_cmap(n)
# For each curve to plot
for i in range(n):
# Retreive the scores in a list for each seed
experiment_scores = list(all_experiment_scores[i].values())
# Compute the mean and the std for the CI
mean_experiment_scores = np.average(experiment_scores, axis=0)
std_experiment_scores = np.std(experiment_scores, axis=0)
# Plot the score curve with the CI
Plotter.plot_mean_and_CI(
ax=ax,
mean=mean_experiment_scores,
lb=mean_experiment_scores + std_experiment_scores,
ub=mean_experiment_scores - std_experiment_scores,
x_value=x_value,
color_mean=colors[i],
facecolor=colors[i],
label=all_labels[i]
)
for j, all_experiment_scores in enumerate([all_experiment_scores_1, all_experiment_scores_2]):
# For each curve to plot
for i in range(n):
# Retreive the scores in a list for each seed
experiment_scores = list(all_experiment_scores[i].values())
# Compute the mean and the std for the CI
mean_experiment_scores = np.average(experiment_scores, axis=0)
std_experiment_scores = np.std(experiment_scores, axis=0)
# Plot the score curve with the CI
Plotter.plot_mean_and_CI(
ax=axes[j],
mean=mean_experiment_scores,
lb=mean_experiment_scores + std_experiment_scores,
ub=mean_experiment_scores - std_experiment_scores,
x_value=x_value,
color_mean=colors[i],
facecolor=colors[i],
label=all_labels[i]
)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
Loading