From 61d553a5570ac522e035acf05695370c804e9b95 Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Tue, 4 Apr 2023 10:24:11 -0400
Subject: [PATCH] Feature importance correction

---
 .../result_analysis/feature_importances.py    | 29 ++++++++-----------
 .../multiview_platform/utils/compression.py   | 18 +++++++-----
 2 files changed, 23 insertions(+), 24 deletions(-)

diff --git a/summit/multiview_platform/result_analysis/feature_importances.py b/summit/multiview_platform/result_analysis/feature_importances.py
index f09cc060..d95fb6e4 100644
--- a/summit/multiview_platform/result_analysis/feature_importances.py
+++ b/summit/multiview_platform/result_analysis/feature_importances.py
@@ -25,13 +25,11 @@ def get_feature_importances(result, feature_ids=None, view_names=None,):
     """
     feature_importances = {}
     for classifier_result in result:
-        print(classifier_result.classifier_name)
         if isinstance(classifier_result, MonoviewResult):
             if classifier_result.view_name not in feature_importances:
                 feature_importances[classifier_result.view_name] = pd.DataFrame(
                     index=feature_ids[classifier_result.view_index])
             if hasattr(classifier_result.clf, 'feature_importances_'):
-                print(classifier_result.classifier_name)
                 feature_importances[classifier_result.view_name][
                     classifier_result.classifier_name] = classifier_result.clf.feature_importances_
             else:
@@ -64,7 +62,6 @@ def publish_feature_importances(feature_importances, directory, database_name,
         os.mkdir(os.path.join(directory, "feature_importances"))
     for view_name, feature_importance in feature_importances.items():
         if view_name!="mv":
-
             if feature_stds is not None:
                 feature_std = feature_stds[view_name]
             else:
@@ -75,37 +72,35 @@ def publish_feature_importances(feature_importances, directory, database_name,
 
 
             importance_dfs.append(feature_importance.set_index(pd.Index([view_name+"-"+ind for ind in list(feature_importance.index)])).fillna(0))
-            # importance_dfs.append(pd.DataFrame(index=[view_name+"-br"],
-            #                                    columns=feature_importance.columns,
-            #                                    data=np.zeros((1, len(
-            #                                        feature_importance.columns)))))
             std_dfs.append(feature_std.set_index(pd.Index([view_name+"-"+ind
                                                            for ind
                                                            in list(feature_std.index)])).fillna(0))
-            # std_dfs.append(pd.DataFrame(index=[view_name + "-br"],
-            #                                    columns=feature_std.columns,
-            #                                    data=np.zeros((1, len(
-            #                                        feature_std.columns)))))
     if "mv" in feature_importances:
         importance_dfs.append(feature_importances["mv"].loc[(feature_importances["mv"] != 0).any(axis=1), :].fillna(0))
     if len(importance_dfs)>0:
-        print(importance_dfs)
         indices=None
+        columns = None
         for df in importance_dfs:
             if indices is None:
                 indices = list(df.index)
             else:
                 indices += [ind for ind in df.index if ind not in indices]
-        feat_imp_df = pd.DataFrame(index=indices)
-        feature_importances_df = pd.concat([feat_imp_df]+importance_dfs, axis=1)
-        print(feature_importances_df)
+            if columns is None:
+                columns = list(df.columns)
+            else:
+                columns += [col for col in df.columns if col not in columns]
+        feature_importances_df = pd.DataFrame(index=indices, columns=columns)
+        for df in importance_dfs:
+            feature_importances_df = feature_importances_df.combine_first(df)
+
         feature_importances_df = feature_importances_df/feature_importances_df.sum(axis=0)
+        feature_std_df = pd.DataFrame(index=indices, columns=columns)
         if len(std_dfs)>0:
-            feature_std_df = pd.concat(std_dfs)
+            for df in std_dfs:
+                feature_std_df = feature_std_df.combine_first(df)
         else:
             feature_std_df = pd.DataFrame()
         if "mv" in feature_importances:
-            # feature_importances_df = pd.concat([feature_importances_df,feature_importances["mv"].loc[(feature_importances["mv"] != 0).any(axis=1), :]], axis=1).fillna(0)
             if feature_stds is not None:
                 feature_std_df = pd.concat([feature_std_df, feature_stds["mv"]], axis=1,).fillna(0)
             else:
diff --git a/summit/multiview_platform/utils/compression.py b/summit/multiview_platform/utils/compression.py
index 0a11b895..7ea525f8 100644
--- a/summit/multiview_platform/utils/compression.py
+++ b/summit/multiview_platform/utils/compression.py
@@ -43,10 +43,14 @@ def remove_compressed(exp_path):
 
 
 if __name__=="__main__":
-    # for dir in os.listdir("/home/baptiste/Documents/Gitwork/summit/results/"):
-    #     print(dir)
-    #     for exp in os.listdir((os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir))):
-    #         print("\t", exp)
-    #         explore_files(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp))
-    explore_files("/home/baptiste/Documents/Gitwork/biobanq_covid_expes/results/")
-    # simplify_plotly("/home//baptiste/Documents/Gitwork/summit/results/hepatitis/debug_started_2022_03_16-15_06_55__/hepatitis-mean_on_10_iter-balanced_accuracy_p.html")
+    for dir in os.listdir("/home/baptiste/Documents/Gitwork/summit/results/"):
+        print(dir)
+        for exp in os.listdir((os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir))):
+            print("\t", exp)
+            if os.path.isdir(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp)):
+                explore_files(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp))
+    # # explore_files("/home/baptiste/Documents/Gitwork/biobanq_covid_expes/results/")
+    # explore_files("/home/baptiste/Documents/Gitwork/summit/results/tnbc_mazid/debug_started_2023_03_24-11_27_46_thesis")
+    # explore_files(
+    #     "/home/baptiste/Documents/Gitwork/summit/results/lives_thesis_EMF/debug_started_2023_03_24-10_02_21_thesis")
+    # # simplify_plotly("/home//baptiste/Documents/Gitwork/summit/results/hepatitis/debug_started_2022_03_16-15_06_55__/hepatitis-mean_on_10_iter-balanced_accuracy_p.html")
-- 
GitLab