diff --git a/summit/multiview_platform/exec_classif.py b/summit/multiview_platform/exec_classif.py index f742a7fd7eb726989d6ccaadc779fad15bffe8e9..11697f4b998b3548093484c6af638cbb4d7ed94d 100644 --- a/summit/multiview_platform/exec_classif.py +++ b/summit/multiview_platform/exec_classif.py @@ -640,9 +640,9 @@ def exec_classif(arguments): # pragma: no cover k_folds = execution.gen_k_folds(stats_iter, args["nb_folds"], stats_iter_random_states) - dataset_files = dataset.init_multiple_datasets(args["pathf"], - args["name"], - nb_cores) + # dataset_files = dataset.init_multiple_datasets(args["pathf"], + # args["name"], + # nb_cores) views, views_indices, all_views = execution.init_views(dataset_var, args[ diff --git a/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py b/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py index c434042044e96cd9cb5947b70aee3c4ace77647f..9dfa2e263a5ad0e54ecd7027c97fe423204749f5 100644 --- a/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py +++ b/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py @@ -1,5 +1,6 @@ from imblearn.ensemble import BalancedBaggingClassifier -from sklearn.tree import DecisionTreeClassifier +import numpy as np + from ..monoview.monoview_utils import BaseMonoviewClassifier from ..utils.base import base_boosting_estimators @@ -27,5 +28,14 @@ class ImbalanceBagging(BaseMonoviewClassifier, BalancedBaggingClassifier): self.weird_strings = {"base_estimator": "class_name"} self.base_estimator_config = base_estimator_config + def fit(self, X, y): + BalancedBaggingClassifier.fit(self, X, y) + self.feature_importances_ = np.zeros(X.shape[1]) + for estim in self.estimators_: + if hasattr(estim['classifier'], 'feature_importances_'): + self.feature_importances_ += estim['classifier'].feature_importances_ + self.feature_importances_ /= np.sum(self.feature_importances_) + return self + diff --git a/summit/multiview_platform/result_analysis/feature_importances.py b/summit/multiview_platform/result_analysis/feature_importances.py index 36c0eb3514b0fa3db388af10803b60f2f245f011..042e4c0d4d744e1681c1fc4e965892adc024ce61 100644 --- a/summit/multiview_platform/result_analysis/feature_importances.py +++ b/summit/multiview_platform/result_analysis/feature_importances.py @@ -140,7 +140,8 @@ def plot_feature_relevance(file_name, feature_importance, for score in score_df.columns: if len(score.split("-"))>1: algo, view = score.split("-") - feature_importance[algo].loc[[ind for ind in feature_importance.index if ind.startswith(view)]]*=score_df[score]['test'] + list_ind = [ind for ind in feature_importance.index if ind.startswith(view)] + feature_importance[algo].loc[list_ind]*=2*(score_df[score]['test']-0.5) else: feature_importance[score] *= score_df[score]['test'] file_name+="_relevance" diff --git a/summit/multiview_platform/utils/dataset.py b/summit/multiview_platform/utils/dataset.py index 151759765792fe5920a7289eed5207ce5aa74aef..600a06618169570e18e400901742f303c53f3402 100644 --- a/summit/multiview_platform/utils/dataset.py +++ b/summit/multiview_platform/utils/dataset.py @@ -458,11 +458,11 @@ class HDF5Dataset(Dataset): for view_index in range(self.nb_view): if "feature_ids-View{}".format(view_index) in self.dataset["Metadata"].keys(): self.feature_ids[view_index] = [feature_id.decode() - if not is_just_number(feature_id.decode()) - else "ID_" + feature_id.decode() - for feature_id in self.dataset["Metadata"]["feature_ids-View{}".format(view_index)]] + if not is_just_number(feature_id.decode()) + else "ID_" + feature_id.decode() + for feature_id in self.dataset["Metadata"]["feature_ids-View{}".format(view_index)]] else: - self.gen_feat_id(view_index) + self.gen_feat_id(view_index) def get_nb_samples(self): """ @@ -503,7 +503,7 @@ class HDF5Dataset(Dataset): seleted labels' names """ selected_labels = self.get_labels(sample_indices) - if decode: + if type(self.dataset["Labels"].attrs["names"][0]) == bytes: return [label_name.decode("utf-8") for label, label_name in enumerate(self.dataset["Labels"].attrs["names"]) @@ -619,10 +619,14 @@ class HDF5Dataset(Dataset): view_names = self.init_view_names(view_names) new_dataset_file["Metadata"].attrs["nbView"] = len(view_names) for new_index, view_name in enumerate(view_names): + del new_dataset_file["Metadata"]["feature_ids-View{}".format(new_index)] + new_dataset_file["Metadata"]["feature_ids-View{}".format(new_index)] = new_dataset_file["Metadata"]["feature_ids-View{}".format(self.view_dict[view_name])] + del new_dataset_file["Metadata"]["feature_ids-View{}".format(self.view_dict[view_name])] self.copy_view(target_dataset=new_dataset_file, source_view_name=view_name, target_view_index=new_index, sample_indices=sample_indices) + new_dataset_file.close() self.update_hdf5_dataset(dataset_file_path)