From d2326aa71ed25c8e5dc6ffe89494b9c00dc10f90 Mon Sep 17 00:00:00 2001 From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr> Date: Thu, 9 Mar 2023 12:40:32 -0500 Subject: [PATCH] Feat_imp_works with nans --- .../monoview_classifiers/ib_svm_rbf.py | 3 +- .../monoview_classifiers/random_foscm.py | 45 +++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 summit/multiview_platform/monoview_classifiers/random_foscm.py diff --git a/summit/multiview_platform/monoview_classifiers/ib_svm_rbf.py b/summit/multiview_platform/monoview_classifiers/ib_svm_rbf.py index 095bc7dc..5f6305e5 100644 --- a/summit/multiview_platform/monoview_classifiers/ib_svm_rbf.py +++ b/summit/multiview_platform/monoview_classifiers/ib_svm_rbf.py @@ -32,7 +32,8 @@ class IBSVMRBF(BaseMonoviewClassifier, BalancedBaggingClassifier): for estim in self.estimators_: if hasattr(estim['classifier'], 'feature_importances_'): self.feature_importances_ += estim['classifier'].feature_importances_ - self.feature_importances_ /= np.sum(self.feature_importances_) + if np.sum(self.feature_importances_)!=0: + self.feature_importances_ /= np.sum(self.feature_importances_) return self diff --git a/summit/multiview_platform/monoview_classifiers/random_foscm.py b/summit/multiview_platform/monoview_classifiers/random_foscm.py new file mode 100644 index 00000000..202ac61b --- /dev/null +++ b/summit/multiview_platform/monoview_classifiers/random_foscm.py @@ -0,0 +1,45 @@ +from sklearn.ensemble import RandomForestClassifier + +from ..monoview.monoview_utils import BaseMonoviewClassifier +from summit.multiview_platform.utils.hyper_parameter_search import CustomRandint + +# Author-Info +__author__ = "Baptiste Bauvin" +__status__ = "Prototype" # Production, Development, Prototype + +classifier_class_name = "RandomForest" + + +class RandomForest(RandomForestClassifier, BaseMonoviewClassifier): + """ + This class is an adaptation of scikit-learn's `RandomForestClassifier <https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html>`_ + + + """ + + def __init__(self, random_state=None, n_estimators=10, + max_depth=None, criterion='gini', **kwargs): + + RandomForestClassifier.__init__(self, + n_estimators=n_estimators, + max_depth=max_depth, + criterion=criterion, + class_weight="balanced", + random_state=random_state + ) + self.param_names = ["n_estimators", "max_depth", "criterion", + "random_state"] + self.classed_params = [] + self.distribs = [CustomRandint(low=1, high=300), + CustomRandint(low=1, high=10), + ["gini", "entropy"], [random_state]] + self.weird_strings = {} + + def get_interpretation(self, directory, base_file_name, y_test, feature_ids, + multiclass=False): + + interpret_string = "" + interpret_string += self.get_feature_importance(directory, + base_file_name, + feature_ids) + return interpret_string -- GitLab