diff --git a/requirements.txt b/requirements.txt index 5fbfaac93c6434d6879ed9f740eeaa886ce2701b..a9e891d8feb344d2919320d705d946056f27d536 100755 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,4 @@ plotly>=4.2.1 matplotlib>=3.1.1 tabulate>=0.8.6 pyscm-ml>=1.0.0 -imbalanced-learn \ No newline at end of file +imbalanced-learn>=0.10.1 \ No newline at end of file diff --git a/summit/multiview_platform/multiview_classifiers/mumbo.py b/summit/multiview_platform/multiview_classifiers/mumbo.py index b3933cbaafccec4f62b70208779d46897549ade9..04d241ad29a87d1b24bdff319f1e7145d74c87cf 100644 --- a/summit/multiview_platform/multiview_classifiers/mumbo.py +++ b/summit/multiview_platform/multiview_classifiers/mumbo.py @@ -90,6 +90,9 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier): np.savetxt(os.path.join(directory, "feature_importances", base_file_name + view_name + "-feature_importances.csv"), feature_importances, delimiter=',') + # CHANGE: Making self.feature_importances_ one array, so he can be easy to use in + # summit.multiview_platform.result_analysis.feature_importances.get_feature_importances + self.feature_importances_ = np.concatenate(self.feature_importances_) self.view_importances /= np.sum(self.view_importances) np.savetxt(os.path.join(directory, base_file_name + "view_importances.csv"), self.view_importances, delimiter=',') diff --git a/summit/multiview_platform/multiview_classifiers/weighted_linear_early_fusion.py b/summit/multiview_platform/multiview_classifiers/weighted_linear_early_fusion.py index 9af0183658e2ebbba32f4c894d1d6fffb4bcf762..c131e9a9a94ab1f768a9fb2958365c9cc3d7076c 100644 --- a/summit/multiview_platform/multiview_classifiers/weighted_linear_early_fusion.py +++ b/summit/multiview_platform/multiview_classifiers/weighted_linear_early_fusion.py @@ -65,6 +65,8 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier): y=y[train_indices]) self.monoview_classifier.fit(X, y[train_indices]) self.monoview_classifier_config = self.monoview_classifier.get_params() + if hasattr(self.monoview_classifier, 'feature_importances_'): + self.feature_importances_ = self.monoview_classifier.feature_importances_ return self def predict(self, X, sample_indices=None, view_indices=None): diff --git a/summit/multiview_platform/result_analysis/feature_importances.py b/summit/multiview_platform/result_analysis/feature_importances.py index 735455abab7676bd95c9e8d821c7042b028ef6a0..36c0eb3514b0fa3db388af10803b60f2f245f011 100644 --- a/summit/multiview_platform/result_analysis/feature_importances.py +++ b/summit/multiview_platform/result_analysis/feature_importances.py @@ -44,7 +44,7 @@ def get_feature_importances(result, feature_ids=None, view_names=None,): v_feature_id] feature_importances["mv"] = pd.DataFrame(index=feat_ids) if hasattr(classifier_result.clf, 'feature_importances_'): - feature_importances["mv"][classifier_result.classifier_name] = np.concatenate(classifier_result.clf.feature_importances_) + feature_importances["mv"][classifier_result.classifier_name] = classifier_result.clf.feature_importances_ return feature_importances