Skip to content
Snippets Groups Projects

Resolve "Experiment pipeline"

Merged Charly Lamothe requested to merge 12-experiment-pipeline into master
3 files
+ 36
25
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -57,10 +57,14 @@ class Plotter(object):
ax.plot(x_value, mean, c=color_mean, label=label)
@staticmethod
def plot_losses(file_path, all_experiment_scores_1, all_experiment_scores_2, x_value, xlabel, ylabel, all_labels, title):
fig, axes = plt.subplots(nrows=1, ncols=2)
def plot_stage1_losses(file_path, all_experiment_scores_with_params,
all_experiment_scores_wo_params, x_value, xlabel, ylabel, all_labels, title):
fig, axes = plt.subplots(nrows=1, ncols=2, sharey=True)
n = len(all_experiment_scores_1)
n = len(all_experiment_scores_with_params)
if n != len(all_experiment_scores_wo_params):
raise ValueError('all_experiment_scores_with_params and all_experiment_scores_wo_params must have the same len to be compared.')
"""
Get as many different colors from the specified cmap (here nipy_spectral)
@@ -68,7 +72,8 @@ class Plotter(object):
"""
colors = Plotter.get_colors_from_cmap(n)
for j, all_experiment_scores in enumerate([all_experiment_scores_1, all_experiment_scores_2]):
for j, all_experiment_scores in enumerate([all_experiment_scores_with_params,
all_experiment_scores_wo_params]):
# For each curve to plot
for i in range(n):
# Retreive the scores in a list for each seed
@@ -88,11 +93,13 @@ class Plotter(object):
label=all_labels[i]
)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.title(title)
plt.legend(loc='upper right')
fig.savefig(file_path, dpi=fig.dpi)
axes[0].set_xlabel(xlabel)
axes[1].set_xlabel(xlabel)
axes[0].set_ylabel(ylabel)
plt.suptitle(title)
handles, labels = axes[0].get_legend_handles_labels()
legend = axes[0].legend(handles, labels, loc='upper center', bbox_to_anchor=(1.1, -0.15))
fig.savefig(file_path, dpi=fig.dpi, bbox_extra_artists=(legend,), bbox_inches='tight')
plt.close(fig)
@staticmethod
Loading