From a99d603146a84999240c507597cff243fed2e743 Mon Sep 17 00:00:00 2001 From: Kossi KOSSIVI <kossi.kossivi@etu.univ-amu.fr> Date: Wed, 7 May 2025 16:22:02 +0200 Subject: [PATCH] Adding feature_importances_ attr to Mumbo and weighted_linear_early_fusion --- requirements.txt | 2 +- summit/multiview_platform/multiview_classifiers/mumbo.py | 3 +++ .../multiview_classifiers/weighted_linear_early_fusion.py | 2 ++ .../multiview_platform/result_analysis/feature_importances.py | 2 +- 4 files changed, 7 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 5fbfaac9..a9e891d8 100755 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,4 @@ plotly>=4.2.1 matplotlib>=3.1.1 tabulate>=0.8.6 pyscm-ml>=1.0.0 -imbalanced-learn \ No newline at end of file +imbalanced-learn>=0.10.1 \ No newline at end of file diff --git a/summit/multiview_platform/multiview_classifiers/mumbo.py b/summit/multiview_platform/multiview_classifiers/mumbo.py index b3933cba..04d241ad 100644 --- a/summit/multiview_platform/multiview_classifiers/mumbo.py +++ b/summit/multiview_platform/multiview_classifiers/mumbo.py @@ -90,6 +90,9 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier): np.savetxt(os.path.join(directory, "feature_importances", base_file_name + view_name + "-feature_importances.csv"), feature_importances, delimiter=',') + # CHANGE: Making self.feature_importances_ one array, so he can be easy to use in + # summit.multiview_platform.result_analysis.feature_importances.get_feature_importances + self.feature_importances_ = np.concatenate(self.feature_importances_) self.view_importances /= np.sum(self.view_importances) np.savetxt(os.path.join(directory, base_file_name + "view_importances.csv"), self.view_importances, delimiter=',') diff --git a/summit/multiview_platform/multiview_classifiers/weighted_linear_early_fusion.py b/summit/multiview_platform/multiview_classifiers/weighted_linear_early_fusion.py index 9af01836..c131e9a9 100644 --- a/summit/multiview_platform/multiview_classifiers/weighted_linear_early_fusion.py +++ b/summit/multiview_platform/multiview_classifiers/weighted_linear_early_fusion.py @@ -65,6 +65,8 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier): y=y[train_indices]) self.monoview_classifier.fit(X, y[train_indices]) self.monoview_classifier_config = self.monoview_classifier.get_params() + if hasattr(self.monoview_classifier, 'feature_importances_'): + self.feature_importances_ = self.monoview_classifier.feature_importances_ return self def predict(self, X, sample_indices=None, view_indices=None): diff --git a/summit/multiview_platform/result_analysis/feature_importances.py b/summit/multiview_platform/result_analysis/feature_importances.py index 735455ab..36c0eb35 100644 --- a/summit/multiview_platform/result_analysis/feature_importances.py +++ b/summit/multiview_platform/result_analysis/feature_importances.py @@ -44,7 +44,7 @@ def get_feature_importances(result, feature_ids=None, view_names=None,): v_feature_id] feature_importances["mv"] = pd.DataFrame(index=feat_ids) if hasattr(classifier_result.clf, 'feature_importances_'): - feature_importances["mv"][classifier_result.classifier_name] = np.concatenate(classifier_result.clf.feature_importances_) + feature_importances["mv"][classifier_result.classifier_name] = classifier_result.clf.feature_importances_ return feature_importances -- GitLab