Skip to content
Snippets Groups Projects
Commit 2ca756c4 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Cuisine relevant modifs

parent 429c7c0d
Branches
No related tags found
No related merge requests found
...@@ -640,9 +640,9 @@ def exec_classif(arguments): # pragma: no cover ...@@ -640,9 +640,9 @@ def exec_classif(arguments): # pragma: no cover
k_folds = execution.gen_k_folds(stats_iter, args["nb_folds"], k_folds = execution.gen_k_folds(stats_iter, args["nb_folds"],
stats_iter_random_states) stats_iter_random_states)
dataset_files = dataset.init_multiple_datasets(args["pathf"], # dataset_files = dataset.init_multiple_datasets(args["pathf"],
args["name"], # args["name"],
nb_cores) # nb_cores)
views, views_indices, all_views = execution.init_views(dataset_var, views, views_indices, all_views = execution.init_views(dataset_var,
args[ args[
......
from imblearn.ensemble import BalancedBaggingClassifier from imblearn.ensemble import BalancedBaggingClassifier
from sklearn.tree import DecisionTreeClassifier import numpy as np
from ..monoview.monoview_utils import BaseMonoviewClassifier from ..monoview.monoview_utils import BaseMonoviewClassifier
from ..utils.base import base_boosting_estimators from ..utils.base import base_boosting_estimators
...@@ -27,5 +28,14 @@ class ImbalanceBagging(BaseMonoviewClassifier, BalancedBaggingClassifier): ...@@ -27,5 +28,14 @@ class ImbalanceBagging(BaseMonoviewClassifier, BalancedBaggingClassifier):
self.weird_strings = {"base_estimator": "class_name"} self.weird_strings = {"base_estimator": "class_name"}
self.base_estimator_config = base_estimator_config self.base_estimator_config = base_estimator_config
def fit(self, X, y):
    """Fit the balanced bagging ensemble, then aggregate feature importances.

    After delegating the actual fitting to ``BalancedBaggingClassifier``,
    this sums the ``feature_importances_`` of every sub-estimator that
    exposes one and normalizes the result to sum to 1.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Training samples.
    y : array-like of shape (n_samples,)
        Target labels.

    Returns
    -------
    self
        The fitted classifier (sklearn convention).
    """
    BalancedBaggingClassifier.fit(self, X, y)
    self.feature_importances_ = np.zeros(X.shape[1])
    # Each fitted sub-estimator is indexed by step name; 'classifier' is
    # presumably the final tree-based step exposing feature_importances_
    # — TODO confirm against imblearn's BalancedBaggingClassifier layout.
    for estim in self.estimators_:
        if hasattr(estim['classifier'], 'feature_importances_'):
            self.feature_importances_ += estim['classifier'].feature_importances_
    # Normalize to a distribution summing to 1. Guard against a zero total
    # (no sub-estimator exposed importances, or all were zero): the
    # original unconditional division produced NaNs in that case.
    total = np.sum(self.feature_importances_)
    if total > 0:
        self.feature_importances_ /= total
    return self
...@@ -140,7 +140,8 @@ def plot_feature_relevance(file_name, feature_importance, ...@@ -140,7 +140,8 @@ def plot_feature_relevance(file_name, feature_importance,
for score in score_df.columns: for score in score_df.columns:
if len(score.split("-"))>1: if len(score.split("-"))>1:
algo, view = score.split("-") algo, view = score.split("-")
feature_importance[algo].loc[[ind for ind in feature_importance.index if ind.startswith(view)]]*=score_df[score]['test'] list_ind = [ind for ind in feature_importance.index if ind.startswith(view)]
feature_importance[algo].loc[list_ind]*=2*(score_df[score]['test']-0.5)
else: else:
feature_importance[score] *= score_df[score]['test'] feature_importance[score] *= score_df[score]['test']
file_name+="_relevance" file_name+="_relevance"
......
...@@ -503,7 +503,7 @@ class HDF5Dataset(Dataset): ...@@ -503,7 +503,7 @@ class HDF5Dataset(Dataset):
selected labels' names selected labels' names
""" """
selected_labels = self.get_labels(sample_indices) selected_labels = self.get_labels(sample_indices)
if decode: if type(self.dataset["Labels"].attrs["names"][0]) == bytes:
return [label_name.decode("utf-8") return [label_name.decode("utf-8")
for label, label_name in for label, label_name in
enumerate(self.dataset["Labels"].attrs["names"]) enumerate(self.dataset["Labels"].attrs["names"])
...@@ -619,10 +619,14 @@ class HDF5Dataset(Dataset): ...@@ -619,10 +619,14 @@ class HDF5Dataset(Dataset):
view_names = self.init_view_names(view_names) view_names = self.init_view_names(view_names)
new_dataset_file["Metadata"].attrs["nbView"] = len(view_names) new_dataset_file["Metadata"].attrs["nbView"] = len(view_names)
for new_index, view_name in enumerate(view_names): for new_index, view_name in enumerate(view_names):
del new_dataset_file["Metadata"]["feature_ids-View{}".format(new_index)]
new_dataset_file["Metadata"]["feature_ids-View{}".format(new_index)] = new_dataset_file["Metadata"]["feature_ids-View{}".format(self.view_dict[view_name])]
del new_dataset_file["Metadata"]["feature_ids-View{}".format(self.view_dict[view_name])]
self.copy_view(target_dataset=new_dataset_file, self.copy_view(target_dataset=new_dataset_file,
source_view_name=view_name, source_view_name=view_name,
target_view_index=new_index, target_view_index=new_index,
sample_indices=sample_indices) sample_indices=sample_indices)
new_dataset_file.close() new_dataset_file.close()
self.update_hdf5_dataset(dataset_file_path) self.update_hdf5_dataset(dataset_file_path)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment