From a99d603146a84999240c507597cff243fed2e743 Mon Sep 17 00:00:00 2001
From: Kossi KOSSIVI <kossi.kossivi@etu.univ-amu.fr>
Date: Wed, 7 May 2025 16:22:02 +0200
Subject: [PATCH] Adding feature_importances_ attr to Mumbo and
 weighted_linear_early_fusion

---
 requirements.txt                                               | 2 +-
 summit/multiview_platform/multiview_classifiers/mumbo.py       | 3 +++
 .../multiview_classifiers/weighted_linear_early_fusion.py      | 2 ++
 .../multiview_platform/result_analysis/feature_importances.py  | 2 +-
 4 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 5fbfaac9..a9e891d8 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,4 +12,4 @@ plotly>=4.2.1
 matplotlib>=3.1.1
 tabulate>=0.8.6
 pyscm-ml>=1.0.0
-imbalanced-learn
\ No newline at end of file
+imbalanced-learn>=0.10.1
\ No newline at end of file
diff --git a/summit/multiview_platform/multiview_classifiers/mumbo.py b/summit/multiview_platform/multiview_classifiers/mumbo.py
index b3933cba..04d241ad 100644
--- a/summit/multiview_platform/multiview_classifiers/mumbo.py
+++ b/summit/multiview_platform/multiview_classifiers/mumbo.py
@@ -90,6 +90,9 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
             np.savetxt(os.path.join(directory, "feature_importances",
                                     base_file_name + view_name + "-feature_importances.csv"),
                        feature_importances, delimiter=',')
+        # CHANGE: Making self.feature_importances_ one array, so he can be easy to use in
+        # summit.multiview_platform.result_analysis.feature_importances.get_feature_importances
+        self.feature_importances_ = np.concatenate(self.feature_importances_)
         self.view_importances /= np.sum(self.view_importances)
         np.savetxt(os.path.join(directory, base_file_name + "view_importances.csv"), self.view_importances,
                    delimiter=',')
diff --git a/summit/multiview_platform/multiview_classifiers/weighted_linear_early_fusion.py b/summit/multiview_platform/multiview_classifiers/weighted_linear_early_fusion.py
index 9af01836..c131e9a9 100644
--- a/summit/multiview_platform/multiview_classifiers/weighted_linear_early_fusion.py
+++ b/summit/multiview_platform/multiview_classifiers/weighted_linear_early_fusion.py
@@ -65,6 +65,8 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier):
                                                     y=y[train_indices])
         self.monoview_classifier.fit(X, y[train_indices])
         self.monoview_classifier_config = self.monoview_classifier.get_params()
+        if hasattr(self.monoview_classifier, 'feature_importances_'):
+            self.feature_importances_ = self.monoview_classifier.feature_importances_
         return self
 
     def predict(self, X, sample_indices=None, view_indices=None):
diff --git a/summit/multiview_platform/result_analysis/feature_importances.py b/summit/multiview_platform/result_analysis/feature_importances.py
index 735455ab..36c0eb35 100644
--- a/summit/multiview_platform/result_analysis/feature_importances.py
+++ b/summit/multiview_platform/result_analysis/feature_importances.py
@@ -44,7 +44,7 @@ def get_feature_importances(result, feature_ids=None, view_names=None,):
                                  v_feature_id]
                 feature_importances["mv"] = pd.DataFrame(index=feat_ids)
             if hasattr(classifier_result.clf, 'feature_importances_'):
-                feature_importances["mv"][classifier_result.classifier_name] = np.concatenate(classifier_result.clf.feature_importances_)
+                feature_importances["mv"][classifier_result.classifier_name] = classifier_result.clf.feature_importances_
     return feature_importances
 
 
-- 
GitLab