From eefd7b6b0a530df474da015b8613b7422e48adb6 Mon Sep 17 00:00:00 2001
From: Charly LAMOTHE <lamothe.c@intlocal.univ-amu.fr>
Date: Wed, 6 Nov 2019 15:18:20 +0100
Subject: [PATCH] - Fix weight_density. TODO: add legend (colored by seed) -
 Add a TODO for the subset in train - Remove an useless new line

---
 code/bolsonaro/trainer.py               |  2 ++
 code/bolsonaro/visualization/plotter.py | 15 ++++++++++-----
 code/compute_results.py                 |  1 -
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py
index 01d0a03..dcc16e0 100644
--- a/code/bolsonaro/trainer.py
+++ b/code/bolsonaro/trainer.py
@@ -30,6 +30,8 @@ class Trainer(object):
             y_omp = y_forest
             self._logger.debug('Fitting both the forest and OMP on train+dev subsets.')
 
+        # TODO: add an option to train forest to train+dev and OMP to dev
+
         model.fit(
             X_forest=X_forest,
             y_forest=y_forest,
diff --git a/code/bolsonaro/visualization/plotter.py b/code/bolsonaro/visualization/plotter.py
index d7ae5a8..86a906d 100644
--- a/code/bolsonaro/visualization/plotter.py
+++ b/code/bolsonaro/visualization/plotter.py
@@ -12,14 +12,19 @@ class Plotter(object):
         :param all_experiment_weights: The weights for the different experiments
         :param file path: str, path where the figure will be saved
         '''
+
+        all_experiment_weights = np.array(list(all_experiment_weights.values()))
+        n = len(all_experiment_weights)
+        colors = Plotter.get_colors_from_cmap(n)
+
         fig, ax = plt.subplots()
-        for weights in all_experiment_weights.values():
-            pd.Series([weight for weight in weights if weight != 0]).plot.kde(figsize=(15, 10), ax=ax)
+        for i in range(n):
+            for weights in all_experiment_weights[i]:
+                pd.Series([weight for weight in weights if weight != 0]).plot.kde(
+                    figsize=(15, 10), ax=ax, color=colors[i])
 
-        legends = ['Experience ' + str(i+1) for i in range(len(all_experiment_weights))]
-        ax.legend(legends)
+        ax.set_title('Density weights of the OMP')
         fig.savefig(file_path, dpi=fig.dpi)
-        fig.title('Density weights of the OMP')
         plt.close(fig)
 
     @staticmethod
diff --git a/code/compute_results.py b/code/compute_results.py
index 616a8d4..78d3027 100644
--- a/code/compute_results.py
+++ b/code/compute_results.py
@@ -117,7 +117,6 @@ if __name__ == "__main__":
         )
 
         # Plot the density of the weights
-
         Plotter.weight_density(
             file_path=args.results_dir + os.sep + experiment_id + os.sep + 'density_weight.png',
             all_experiment_weights=experiment_weights
-- 
GitLab