diff --git a/code/bolsonaro/visualization/plotter.py b/code/bolsonaro/visualization/plotter.py index 86a906daa1576943f98edc9b73b965d2b2bac608..0d5706bc27cb0745fe065456231b7e3023707ac9 100644 --- a/code/bolsonaro/visualization/plotter.py +++ b/code/bolsonaro/visualization/plotter.py @@ -7,20 +7,42 @@ class Plotter(object): @staticmethod def weight_density(all_experiment_weights, file_path): - ''' + """ Function that creates the figure with the density of the weights :param all_experiment_weights: The weights for the different experiments :param file path: str, path where the figure will be saved - ''' + TODO: colored by seed number or not? + TODO: represents both the seed AND the extracted tree information in the legend + """ + """ + Convert dictionnary of structure + {seed_1:[M x W]], seed_k:[M x W]} + to numpy.ndarray with dim [K x M x W] + where K is the seed number, M is the + number of extracted trees and W the + weight number. + """ all_experiment_weights = np.array(list(all_experiment_weights.values())) + n = len(all_experiment_weights) + + """ + Get as many different colors from the specified cmap (here nipy_spectral) + as there are seeds used. + """ colors = Plotter.get_colors_from_cmap(n) fig, ax = plt.subplots() + # For each seed for i in range(n): + # For each weight set of a given extracted tree number for weights in all_experiment_weights[i]: - pd.Series([weight for weight in weights if weight != 0]).plot.kde( + """ + Plot the series of weights that aren't zero, + colored by seed number. + """ + pd.Series(weights[np.nonzero(weights)]).plot.kde( figsize=(15, 10), ax=ax, color=colors[i]) ax.set_title('Density weights of the OMP')