From e1d410a58885908249a32af666b8f815bcfe606b Mon Sep 17 00:00:00 2001 From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr> Date: Tue, 6 Sep 2022 10:13:11 -0400 Subject: [PATCH] Working order --- .../multiview/multiview_utils.py | 1 - .../multiview_classifiers/mucombo.py | 47 ++++++++++- .../multiview_classifiers/mumbo.py | 2 +- .../multiview_classifiers/spkm_pw.py | 79 +++++++++++++++++++ .../result_analysis/feature_importances.py | 7 +- 5 files changed, 130 insertions(+), 6 deletions(-) create mode 100644 summit/multiview_platform/multiview_classifiers/spkm_pw.py diff --git a/summit/multiview_platform/multiview/multiview_utils.py b/summit/multiview_platform/multiview/multiview_utils.py index 032f181a..93a56639 100644 --- a/summit/multiview_platform/multiview/multiview_utils.py +++ b/summit/multiview_platform/multiview/multiview_utils.py @@ -28,7 +28,6 @@ class BaseMultiviewClassifier(BaseClassifier): self.used_views = None def set_base_estim_from_dict(self, base_estim_dict, **kwargs): - print(base_estim_dict) if base_estim_dict is None: base_estimator = DecisionTreeClassifier() elif isinstance(base_estim_dict, str) and kwargs is not None: diff --git a/summit/multiview_platform/multiview_classifiers/mucombo.py b/summit/multiview_platform/multiview_classifiers/mucombo.py index ad5268e6..784b34f8 100644 --- a/summit/multiview_platform/multiview_classifiers/mucombo.py +++ b/summit/multiview_platform/multiview_classifiers/mucombo.py @@ -1,16 +1,17 @@ from sklearn.tree import DecisionTreeClassifier +from sklearn.base import BaseEstimator - +import numpy as np from multimodal.boosting.combo import MuComboClassifier from ..multiview.multiview_utils import BaseMultiviewClassifier from ..utils.hyper_parameter_search import CustomRandint from ..utils.dataset import get_samples_views_indices from ..utils.base import base_boosting_estimators -classifier_class_name = "MuCumbo" +classifier_class_name = "MuCombo" -class MuCumbo(BaseMultiviewClassifier, MuComboClassifier): +class MuCombo(BaseMultiviewClassifier, MuComboClassifier): def __init__(self, base_estimator=None, n_estimators=50, @@ -31,9 +32,33 @@ class MuCumbo(BaseMultiviewClassifier, MuComboClassifier): self.used_views = view_indices numpy_X, view_limits = X.to_numpy_array(sample_indices=train_indices, view_indices=view_indices) + self.view_shapes = [view_limits[ind + 1] - view_limits[ind] + for ind in range(len(self.used_views))] return MuComboClassifier.fit(self, numpy_X, y[train_indices], view_limits) + def set_params(self, base_estimator=None, **params): + """ + Sets the base estimator from a dict. + :param base_estimator: + :param params: + :return: + """ + if base_estimator is None: + self.base_estimator = DecisionTreeClassifier() + elif type(base_estimator) is list: + if type(base_estimator[0]) is dict: + self.base_estimator = [self.set_base_estim_from_dict(estim) for estim in base_estimator] + elif isinstance(base_estimator[0], BaseEstimator): + self.base_estimator = base_estimator + else: + raise ValueError("base_estimator should ba a list of dict or a sklearn classifier list") + elif isinstance(base_estimator, dict): + self.base_estimator = self.set_base_estim_from_dict(base_estimator) + MuComboClassifier.set_params(self, **params) + else: + MuComboClassifier.set_params(self, base_estimator=base_estimator, **params) + def predict(self, X, sample_indices=None, view_indices=None): sample_indices, view_indices = get_samples_views_indices(X, sample_indices, @@ -43,6 +68,22 @@ class MuCumbo(BaseMultiviewClassifier, MuComboClassifier): view_indices=view_indices) return MuComboClassifier.predict(self, numpy_X) + def accepts_multi_class(self, random_state, n_samples=10, dim=2, + n_classes=3, n_views=2): + return True + def get_interpretation(self, directory, base_file_name, labels, multiclass=False): + self.feature_importances_ = [np.zeros(view_shape) + for view_shape in self.view_shapes] + # for best_view, estimator_weight, estimator in zip(self.best_views_, self.estimator_weights_, self.estimators_): + # self.view_importances[best_view] += estimator_weight + # if hasattr(estimator, "feature_importances_"): + # self.feature_importances_[best_view] += estimator.feature_importances_ + # importances_sum = sum([np.sum(feature_importances) + # for feature_importances + # in self.feature_importances_]) + # self.feature_importances_ = [feature_importances/importances_sum + # for feature_importances + # in self.feature_importances_] return "" diff --git a/summit/multiview_platform/multiview_classifiers/mumbo.py b/summit/multiview_platform/multiview_classifiers/mumbo.py index 203228ba..1bb2f3db 100644 --- a/summit/multiview_platform/multiview_classifiers/mumbo.py +++ b/summit/multiview_platform/multiview_classifiers/mumbo.py @@ -72,7 +72,6 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier): view_indices=view_indices) self.view_shapes = [view_limits[ind+1]-view_limits[ind] for ind in range(len(self.used_views)) ] - return MumboClassifier.fit(self, numpy_X, y[train_indices], view_limits) @@ -118,6 +117,7 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier): self.view_importances[view_index]) interpret_string +="\n The boosting process selected views : \n" + ", ".join(map(str, self.best_views_)) interpret_string+="\n\n With estimator weights : \n"+ "\n".join(map(str,self.estimator_weights_/np.sum(self.estimator_weights_))) + print(self.feature_importances_) return interpret_string def accepts_multi_class(self, random_state, n_samples=10, dim=2, diff --git a/summit/multiview_platform/multiview_classifiers/spkm_pw.py b/summit/multiview_platform/multiview_classifiers/spkm_pw.py new file mode 100644 index 00000000..4c5a51ad --- /dev/null +++ b/summit/multiview_platform/multiview_classifiers/spkm_pw.py @@ -0,0 +1,79 @@ +import numpy as np +from sklearn.preprocessing import LabelBinarizer + +from spkm.spkm_wrapper import pairwiseSPKMlikeSklearn +from spkm.kernels_and_gradients import RBFKernel, PolyKernel + +from ..multiview.multiview_utils import BaseMultiviewClassifier +from ..utils.hyper_parameter_search import CustomRandint +from ..utils.dataset import get_samples_views_indices + +classifier_class_name = "PWSPKM" + +class PWSPKM(BaseMultiviewClassifier, pairwiseSPKMlikeSklearn): + + def __init__(self, random_state=42, n_u=2, kernel1=RBFKernel(0.5), + kernel2=RBFKernel(0.5), spkmregP=1, spkminit="randn", + nspkminits=10, preprocessinglist=[0,1,2], **kwargs): + BaseMultiviewClassifier.__init__(self, random_state) + + pairwiseSPKMlikeSklearn.__init__(self, random_state=random_state, + n_u=n_u, + kernel1=kernel1, + kernel2=kernel2, + spkmregP=spkmregP, + spkminit=spkminit, + nspkminits=nspkminits, + preprocessinglist=preprocessinglist) + self.param_names = ["n_u", "kernel1", "kernel2", "spkmregP", + "spkminit", "nspkminits", "preprocessinglist", + "random_state"] + self.distribs = [[2], [PolyKernel({"d":3, "r":1})], [PolyKernel({"d":3, "r":1})], CustomRandint(1,15), + ["data", "randn"], CustomRandint(1,30), + [[], [0], [0,1], [0,1,2]], [random_state],] + self.more_than_two_views = False + self.random_state = random_state + + def fit(self, X, y, train_indices=None, view_indices=None): + + self.lb = LabelBinarizer(pos_label=1, neg_label=-1) + y = self.lb.fit_transform(y) + train_indices, view_indices = get_samples_views_indices(X, + train_indices, + view_indices) + if len(view_indices)>2: + self.more_than_two_views = True + self.label_set = np.unique(y) + return self + self.used_views = view_indices + self.view_names = [X.get_view_name(view_index) + for view_index in view_indices] + view_list = [X.get_v(view_index)[train_indices, :] + for view_index in view_indices] + + return pairwiseSPKMlikeSklearn.fit(self, view_list, y[train_indices,0],) + + def predict(self, X, sample_indices=None, view_indices=None): + if self.more_than_two_views: + return self.random_state.choice(self.label_set, replace=True, size=X.shape[0]) + sample_indices, view_indices = get_samples_views_indices(X, + sample_indices, + view_indices) + view_list = [X.get_v(view_index)[sample_indices, :] + for view_index in view_indices] + self._check_views(view_indices) + + view_list = [X.get_v(view_index)[sample_indices, :] + for view_index in view_indices] + print(self.lb.inverse_transform(np.sign(pairwiseSPKMlikeSklearn.predict(self, view_list)))) + return self.lb.inverse_transform(np.sign(pairwiseSPKMlikeSklearn.predict(self, view_list))) + + def get_interpretation(self, directory, base_file_name, labels, multiclass=False): + u, v = self.feature_interpretability() + importances_sum = np.sum(u+v) + self.feature_importances_ = [u/importances_sum, v/importances_sum] + return "" + + def accepts_multi_class(self, random_state, n_samples=10, dim=2, + n_classes=3, n_views=2): + return False diff --git a/summit/multiview_platform/result_analysis/feature_importances.py b/summit/multiview_platform/result_analysis/feature_importances.py index b7fe1186..b40161da 100644 --- a/summit/multiview_platform/result_analysis/feature_importances.py +++ b/summit/multiview_platform/result_analysis/feature_importances.py @@ -30,6 +30,7 @@ def get_feature_importances(result, feature_ids=None, view_names=None,): feature_importances[classifier_result.view_name] = pd.DataFrame( index=feature_ids[classifier_result.view_index]) if hasattr(classifier_result.clf, 'feature_importances_'): + print(classifier_result.classifier_name, classifier_result.view_name) feature_importances[classifier_result.view_name][ classifier_result.classifier_name] = classifier_result.clf.feature_importances_ else: @@ -44,7 +45,11 @@ def get_feature_importances(result, feature_ids=None, view_names=None,): v_feature_id] feature_importances["mv"] = pd.DataFrame(index=feat_ids) if hasattr(classifier_result.clf, 'feature_importances_'): - feature_importances["mv"][classifier_result.classifier_name] = classifier_result.clf.feature_importances_ + if len(classifier_result.clf.feature_importances_)==len(feature_ids): + concat = np.concatenate(classifier_result.clf.feature_importances_, axis=0) + feature_importances["mv"][classifier_result.classifier_name] = concat/np.sum(concat) + else: + feature_importances["mv"][classifier_result.classifier_name] = classifier_result.clf.feature_importances_ return feature_importances -- GitLab