From 2ca756c48fcf6525da8f0981b78540c9af9fcd61 Mon Sep 17 00:00:00 2001 From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr> Date: Wed, 22 Jun 2022 07:45:29 -0400 Subject: [PATCH] Cuisine relevant modifs --- summit/multiview_platform/exec_classif.py | 6 +++--- .../monoview_classifiers/imbalance_bagging.py | 12 +++++++++++- .../result_analysis/feature_importances.py | 3 ++- summit/multiview_platform/utils/dataset.py | 14 +++++++++----- 4 files changed, 25 insertions(+), 10 deletions(-) diff --git a/summit/multiview_platform/exec_classif.py b/summit/multiview_platform/exec_classif.py index f742a7fd..11697f4b 100644 --- a/summit/multiview_platform/exec_classif.py +++ b/summit/multiview_platform/exec_classif.py @@ -640,9 +640,9 @@ def exec_classif(arguments): # pragma: no cover k_folds = execution.gen_k_folds(stats_iter, args["nb_folds"], stats_iter_random_states) - dataset_files = dataset.init_multiple_datasets(args["pathf"], - args["name"], - nb_cores) + # dataset_files = dataset.init_multiple_datasets(args["pathf"], + # args["name"], + # nb_cores) views, views_indices, all_views = execution.init_views(dataset_var, args[ diff --git a/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py b/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py index c4340420..9dfa2e26 100644 --- a/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py +++ b/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py @@ -1,5 +1,6 @@ from imblearn.ensemble import BalancedBaggingClassifier -from sklearn.tree import DecisionTreeClassifier +import numpy as np + from ..monoview.monoview_utils import BaseMonoviewClassifier from ..utils.base import base_boosting_estimators @@ -27,5 +28,14 @@ class ImbalanceBagging(BaseMonoviewClassifier, BalancedBaggingClassifier): self.weird_strings = {"base_estimator": "class_name"} self.base_estimator_config = base_estimator_config + def fit(self, X, y): + BalancedBaggingClassifier.fit(self, X, y) + self.feature_importances_ = np.zeros(X.shape[1]) + for estim in self.estimators_: + if hasattr(estim['classifier'], 'feature_importances_'): + self.feature_importances_ += estim['classifier'].feature_importances_ + self.feature_importances_ /= np.sum(self.feature_importances_) + return self + diff --git a/summit/multiview_platform/result_analysis/feature_importances.py b/summit/multiview_platform/result_analysis/feature_importances.py index 36c0eb35..042e4c0d 100644 --- a/summit/multiview_platform/result_analysis/feature_importances.py +++ b/summit/multiview_platform/result_analysis/feature_importances.py @@ -140,7 +140,8 @@ def plot_feature_relevance(file_name, feature_importance, for score in score_df.columns: if len(score.split("-"))>1: algo, view = score.split("-") - feature_importance[algo].loc[[ind for ind in feature_importance.index if ind.startswith(view)]]*=score_df[score]['test'] + list_ind = [ind for ind in feature_importance.index if ind.startswith(view)] + feature_importance[algo].loc[list_ind]*=2*(score_df[score]['test']-0.5) else: feature_importance[score] *= score_df[score]['test'] file_name+="_relevance" diff --git a/summit/multiview_platform/utils/dataset.py b/summit/multiview_platform/utils/dataset.py index 15175976..600a0661 100644 --- a/summit/multiview_platform/utils/dataset.py +++ b/summit/multiview_platform/utils/dataset.py @@ -458,11 +458,11 @@ class HDF5Dataset(Dataset): for view_index in range(self.nb_view): if "feature_ids-View{}".format(view_index) in self.dataset["Metadata"].keys(): self.feature_ids[view_index] = [feature_id.decode() - if not is_just_number(feature_id.decode()) - else "ID_" + feature_id.decode() - for feature_id in self.dataset["Metadata"]["feature_ids-View{}".format(view_index)]] + if not is_just_number(feature_id.decode()) + else "ID_" + feature_id.decode() + for feature_id in self.dataset["Metadata"]["feature_ids-View{}".format(view_index)]] else: - self.gen_feat_id(view_index) + self.gen_feat_id(view_index) def get_nb_samples(self): """ @@ -503,7 +503,7 @@ class HDF5Dataset(Dataset): seleted labels' names """ selected_labels = self.get_labels(sample_indices) - if decode: + if type(self.dataset["Labels"].attrs["names"][0]) == bytes: return [label_name.decode("utf-8") for label, label_name in enumerate(self.dataset["Labels"].attrs["names"]) @@ -619,10 +619,14 @@ class HDF5Dataset(Dataset): view_names = self.init_view_names(view_names) new_dataset_file["Metadata"].attrs["nbView"] = len(view_names) for new_index, view_name in enumerate(view_names): + del new_dataset_file["Metadata"]["feature_ids-View{}".format(new_index)] + new_dataset_file["Metadata"]["feature_ids-View{}".format(new_index)] = new_dataset_file["Metadata"]["feature_ids-View{}".format(self.view_dict[view_name])] + del new_dataset_file["Metadata"]["feature_ids-View{}".format(self.view_dict[view_name])] self.copy_view(target_dataset=new_dataset_file, source_view_name=view_name, target_view_index=new_index, sample_indices=sample_indices) + new_dataset_file.close() self.update_hdf5_dataset(dataset_file_path) -- GitLab