Skip to content
Snippets Groups Projects
Commit d2326aa7 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Feat_imp_works with nans

parent 279c4c8e
No related branches found
No related tags found
No related merge requests found
Pipeline #11528 failed
...@@ -32,7 +32,8 @@ class IBSVMRBF(BaseMonoviewClassifier, BalancedBaggingClassifier): ...@@ -32,7 +32,8 @@ class IBSVMRBF(BaseMonoviewClassifier, BalancedBaggingClassifier):
for estim in self.estimators_: for estim in self.estimators_:
if hasattr(estim['classifier'], 'feature_importances_'): if hasattr(estim['classifier'], 'feature_importances_'):
self.feature_importances_ += estim['classifier'].feature_importances_ self.feature_importances_ += estim['classifier'].feature_importances_
self.feature_importances_ /= np.sum(self.feature_importances_) if np.sum(self.feature_importances_)!=0:
self.feature_importances_ /= np.sum(self.feature_importances_)
return self return self
......
from sklearn.ensemble import RandomForestClassifier
from ..monoview.monoview_utils import BaseMonoviewClassifier
from summit.multiview_platform.utils.hyper_parameter_search import CustomRandint
# Author-Info
__author__ = "Baptiste Bauvin"
__status__ = "Prototype" # Production, Development, Prototype
classifier_class_name = "RandomForest"
class RandomForest(RandomForestClassifier, BaseMonoviewClassifier):
"""
This class is an adaptation of scikit-learn's `RandomForestClassifier <https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html>`_
"""
def __init__(self, random_state=None, n_estimators=10,
max_depth=None, criterion='gini', **kwargs):
RandomForestClassifier.__init__(self,
n_estimators=n_estimators,
max_depth=max_depth,
criterion=criterion,
class_weight="balanced",
random_state=random_state
)
self.param_names = ["n_estimators", "max_depth", "criterion",
"random_state"]
self.classed_params = []
self.distribs = [CustomRandint(low=1, high=300),
CustomRandint(low=1, high=10),
["gini", "entropy"], [random_state]]
self.weird_strings = {}
def get_interpretation(self, directory, base_file_name, y_test, feature_ids,
multiclass=False):
interpret_string = ""
interpret_string += self.get_feature_importance(directory,
base_file_name,
feature_ids)
return interpret_string
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment