diff --git a/summit/multiview_platform/monoview_classifiers/ib_svm_rbf.py b/summit/multiview_platform/monoview_classifiers/ib_svm_rbf.py index 095bc7dcd63864733d8b2ffa299378fc4f6646ca..5f6305e5f891b64e60c0a9c2bb131983461cdaf6 100644 --- a/summit/multiview_platform/monoview_classifiers/ib_svm_rbf.py +++ b/summit/multiview_platform/monoview_classifiers/ib_svm_rbf.py @@ -32,7 +32,8 @@ class IBSVMRBF(BaseMonoviewClassifier, BalancedBaggingClassifier): for estim in self.estimators_: if hasattr(estim['classifier'], 'feature_importances_'): self.feature_importances_ += estim['classifier'].feature_importances_ - self.feature_importances_ /= np.sum(self.feature_importances_) + if np.sum(self.feature_importances_)!=0: + self.feature_importances_ /= np.sum(self.feature_importances_) return self diff --git a/summit/multiview_platform/monoview_classifiers/random_foscm.py b/summit/multiview_platform/monoview_classifiers/random_foscm.py new file mode 100644 index 0000000000000000000000000000000000000000..202ac61b78c86db12927146edc63d58edc8dad39 --- /dev/null +++ b/summit/multiview_platform/monoview_classifiers/random_foscm.py @@ -0,0 +1,45 @@ +from sklearn.ensemble import RandomForestClassifier + +from ..monoview.monoview_utils import BaseMonoviewClassifier +from summit.multiview_platform.utils.hyper_parameter_search import CustomRandint + +# Author-Info +__author__ = "Baptiste Bauvin" +__status__ = "Prototype" # Production, Development, Prototype + +classifier_class_name = "RandomForest" + + +class RandomForest(RandomForestClassifier, BaseMonoviewClassifier): + """ + This class is an adaptation of scikit-learn's `RandomForestClassifier <https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html>`_ + + + """ + + def __init__(self, random_state=None, n_estimators=10, + max_depth=None, criterion='gini', **kwargs): + + RandomForestClassifier.__init__(self, + n_estimators=n_estimators, + max_depth=max_depth, + criterion=criterion, + class_weight="balanced", + random_state=random_state + ) + self.param_names = ["n_estimators", "max_depth", "criterion", + "random_state"] + self.classed_params = [] + self.distribs = [CustomRandint(low=1, high=300), + CustomRandint(low=1, high=10), + ["gini", "entropy"], [random_state]] + self.weird_strings = {} + + def get_interpretation(self, directory, base_file_name, y_test, feature_ids, + multiclass=False): + + interpret_string = "" + interpret_string += self.get_feature_importance(directory, + base_file_name, + feature_ids) + return interpret_string