From 2ca756c48fcf6525da8f0981b78540c9af9fcd61 Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Wed, 22 Jun 2022 07:45:29 -0400
Subject: [PATCH] Disable multi-dataset init, add feature importances to ImbalanceBagging, rescale feature relevance, and fix HDF5 label decoding / view filtering

---
 summit/multiview_platform/exec_classif.py          |  6 +++---
 .../monoview_classifiers/imbalance_bagging.py      | 12 +++++++++++-
 .../result_analysis/feature_importances.py         |  3 ++-
 summit/multiview_platform/utils/dataset.py         | 14 +++++++++-----
 4 files changed, 25 insertions(+), 10 deletions(-)

diff --git a/summit/multiview_platform/exec_classif.py b/summit/multiview_platform/exec_classif.py
index f742a7fd..11697f4b 100644
--- a/summit/multiview_platform/exec_classif.py
+++ b/summit/multiview_platform/exec_classif.py
@@ -640,9 +640,9 @@ def exec_classif(arguments):  # pragma: no cover
         k_folds = execution.gen_k_folds(stats_iter, args["nb_folds"],
                                         stats_iter_random_states)
 
-        dataset_files = dataset.init_multiple_datasets(args["pathf"],
-                                                       args["name"],
-                                                       nb_cores)
+        # dataset_files = dataset.init_multiple_datasets(args["pathf"],
+        #                                                args["name"],
+        #                                                nb_cores)
 
         views, views_indices, all_views = execution.init_views(dataset_var,
                                                                args[
diff --git a/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py b/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py
index c4340420..9dfa2e26 100644
--- a/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py
+++ b/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py
@@ -1,5 +1,6 @@
 from imblearn.ensemble import BalancedBaggingClassifier
-from sklearn.tree import DecisionTreeClassifier
+import numpy as np
+
 
 from ..monoview.monoview_utils import BaseMonoviewClassifier
 from ..utils.base import base_boosting_estimators
@@ -27,5 +28,14 @@ class ImbalanceBagging(BaseMonoviewClassifier, BalancedBaggingClassifier):
         self.weird_strings = {"base_estimator": "class_name"}
         self.base_estimator_config = base_estimator_config
 
+    def fit(self, X, y):
+        BalancedBaggingClassifier.fit(self, X, y)
+        self.feature_importances_ = np.zeros(X.shape[1])
+        for estim in self.estimators_:
+            if hasattr(estim['classifier'], 'feature_importances_'):
+                self.feature_importances_ += estim['classifier'].feature_importances_
+        self.feature_importances_ /= np.sum(self.feature_importances_)
+        return self
+
 
 
diff --git a/summit/multiview_platform/result_analysis/feature_importances.py b/summit/multiview_platform/result_analysis/feature_importances.py
index 36c0eb35..042e4c0d 100644
--- a/summit/multiview_platform/result_analysis/feature_importances.py
+++ b/summit/multiview_platform/result_analysis/feature_importances.py
@@ -140,7 +140,8 @@ def plot_feature_relevance(file_name, feature_importance,
             for score in score_df.columns:
                 if len(score.split("-"))>1:
                     algo, view = score.split("-")
-                    feature_importance[algo].loc[[ind for ind in feature_importance.index if ind.startswith(view)]]*=score_df[score]['test']
+                    list_ind = [ind for ind in feature_importance.index if ind.startswith(view)]
+                    feature_importance[algo].loc[list_ind]*=2*(score_df[score]['test']-0.5)
                 else:
                     feature_importance[score] *= score_df[score]['test']
     file_name+="_relevance"
diff --git a/summit/multiview_platform/utils/dataset.py b/summit/multiview_platform/utils/dataset.py
index 15175976..600a0661 100644
--- a/summit/multiview_platform/utils/dataset.py
+++ b/summit/multiview_platform/utils/dataset.py
@@ -458,11 +458,11 @@ class HDF5Dataset(Dataset):
         for view_index in range(self.nb_view):
             if "feature_ids-View{}".format(view_index) in self.dataset["Metadata"].keys():
                 self.feature_ids[view_index] = [feature_id.decode()
-                                   if not is_just_number(feature_id.decode())
-                                   else "ID_" + feature_id.decode()
-                                       for feature_id in self.dataset["Metadata"]["feature_ids-View{}".format(view_index)]]
+                                                if not is_just_number(feature_id.decode())
+                                                else "ID_" + feature_id.decode()
+                                                for feature_id in self.dataset["Metadata"]["feature_ids-View{}".format(view_index)]]
             else:
-               self.gen_feat_id(view_index)
+                self.gen_feat_id(view_index)
 
     def get_nb_samples(self):
         """
@@ -503,7 +503,7 @@ class HDF5Dataset(Dataset):
             seleted labels' names
         """
         selected_labels = self.get_labels(sample_indices)
-        if decode:
+        if type(self.dataset["Labels"].attrs["names"][0]) == bytes:
             return [label_name.decode("utf-8")
                     for label, label_name in
                     enumerate(self.dataset["Labels"].attrs["names"])
@@ -619,10 +619,14 @@ class HDF5Dataset(Dataset):
         view_names = self.init_view_names(view_names)
         new_dataset_file["Metadata"].attrs["nbView"] = len(view_names)
         for new_index, view_name in enumerate(view_names):
+            del new_dataset_file["Metadata"]["feature_ids-View{}".format(new_index)]
+            new_dataset_file["Metadata"]["feature_ids-View{}".format(new_index)] = new_dataset_file["Metadata"]["feature_ids-View{}".format(self.view_dict[view_name])]
+            del new_dataset_file["Metadata"]["feature_ids-View{}".format(self.view_dict[view_name])]
             self.copy_view(target_dataset=new_dataset_file,
                            source_view_name=view_name,
                            target_view_index=new_index,
                            sample_indices=sample_indices)
+
         new_dataset_file.close()
         self.update_hdf5_dataset(dataset_file_path)
 
-- 
GitLab