From d0e487aa9c8059b71a2709f1e4e38d09192ce7fd Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Wed, 1 Feb 2023 09:20:18 -0500
Subject: [PATCH] Feat importance

---
 .../result_analysis/feature_importances.py           |  6 +++++-
 summit/multiview_platform/utils/compression.py       | 12 +++++++-----
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/summit/multiview_platform/result_analysis/feature_importances.py b/summit/multiview_platform/result_analysis/feature_importances.py
index 9a515277..50e5e2c6 100644
--- a/summit/multiview_platform/result_analysis/feature_importances.py
+++ b/summit/multiview_platform/result_analysis/feature_importances.py
@@ -31,8 +31,13 @@ def get_feature_importances(result, feature_ids=None, view_names=None,):
                     index=feature_ids[classifier_result.view_index])
             if hasattr(classifier_result.clf, 'feature_importances_'):
                 print(classifier_result.classifier_name, classifier_result.view_name)
+
                 feature_importances[classifier_result.view_name][
                     classifier_result.classifier_name] = classifier_result.clf.feature_importances_
+                print(classifier_result.clf.feature_importances_.shape,
+                      feature_importances[classifier_result.view_name][
+                          classifier_result.classifier_name].shape)
+
             else:
                 feature_importances[classifier_result.view_name][
                     classifier_result.classifier_name] = np.zeros(
@@ -149,7 +154,6 @@ def plot_feature_relevance(file_name, feature_importance,
             if isinstance(score_df, dict):
                 score_df = score_df["mean"]
             for score in score_df.columns:
-                print(score)
                 if len(score.split("-"))>1:
                     algo, view = score.split("-")
                     list_ind = [ind for ind in feature_importance.index if ind.startswith(view)]
diff --git a/summit/multiview_platform/utils/compression.py b/summit/multiview_platform/utils/compression.py
index 4af07e0a..24b0ebc7 100644
--- a/summit/multiview_platform/utils/compression.py
+++ b/summit/multiview_platform/utils/compression.py
@@ -43,9 +43,11 @@ def remove_compressed(exp_path):
 
 
 if __name__=="__main__":
-    for dir in os.listdir("/home/baptiste/Documents/Gitwork/summit/results/"):
-        print(dir)
-        for exp in os.listdir((os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir))):
-            print("\t", exp)
-            explore_files(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp))
+    # for dir in os.listdir("/home/baptiste/Documents/Gitwork/summit/results/"):
+    #     print(dir)
+    #     for exp in os.listdir((os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir))):
+    #         print("\t", exp)
+    #         explore_files(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp))
+    explore_files(
+        os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", "mage_dset", "debug_started_2022_12_13-10_15_20_th"))
     # simplify_plotly("/home//baptiste/Documents/Gitwork/summit/results/hepatitis/debug_started_2022_03_16-15_06_55__/hepatitis-mean_on_10_iter-balanced_accuracy_p.html")
-- 
GitLab