From eefd7b6b0a530df474da015b8613b7422e48adb6 Mon Sep 17 00:00:00 2001 From: Charly LAMOTHE <lamothe.c@intlocal.univ-amu.fr> Date: Wed, 6 Nov 2019 15:18:20 +0100 Subject: [PATCH] - Fix weight_density. TODO: add legend (colored by seed) - Add a TODO for the subset in train - Remove an useless new line --- code/bolsonaro/trainer.py | 2 ++ code/bolsonaro/visualization/plotter.py | 15 ++++++++++----- code/compute_results.py | 1 - 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py index 01d0a03..dcc16e0 100644 --- a/code/bolsonaro/trainer.py +++ b/code/bolsonaro/trainer.py @@ -30,6 +30,8 @@ class Trainer(object): y_omp = y_forest self._logger.debug('Fitting both the forest and OMP on train+dev subsets.') + # TODO: add an option to train forest to train+dev and OMP to dev + model.fit( X_forest=X_forest, y_forest=y_forest, diff --git a/code/bolsonaro/visualization/plotter.py b/code/bolsonaro/visualization/plotter.py index d7ae5a8..86a906d 100644 --- a/code/bolsonaro/visualization/plotter.py +++ b/code/bolsonaro/visualization/plotter.py @@ -12,14 +12,19 @@ class Plotter(object): :param all_experiment_weights: The weights for the different experiments :param file path: str, path where the figure will be saved ''' + + all_experiment_weights = np.array(list(all_experiment_weights.values())) + n = len(all_experiment_weights) + colors = Plotter.get_colors_from_cmap(n) + fig, ax = plt.subplots() - for weights in all_experiment_weights.values(): - pd.Series([weight for weight in weights if weight != 0]).plot.kde(figsize=(15, 10), ax=ax) + for i in range(n): + for weights in all_experiment_weights[i]: + pd.Series([weight for weight in weights if weight != 0]).plot.kde( + figsize=(15, 10), ax=ax, color=colors[i]) - legends = ['Experience ' + str(i+1) for i in range(len(all_experiment_weights))] - ax.legend(legends) + ax.set_title('Density weights of the OMP') fig.savefig(file_path, dpi=fig.dpi) - fig.title('Density weights of the OMP') plt.close(fig) @staticmethod diff --git a/code/compute_results.py b/code/compute_results.py index 616a8d4..78d3027 100644 --- a/code/compute_results.py +++ b/code/compute_results.py @@ -117,7 +117,6 @@ if __name__ == "__main__": ) # Plot the density of the weights - Plotter.weight_density( file_path=args.results_dir + os.sep + experiment_id + os.sep + 'density_weight.png', all_experiment_weights=experiment_weights -- GitLab