Skip to content
Snippets Groups Projects
Unverified Commit 2bc5e81a authored by Dominique Benielli's avatar Dominique Benielli Committed by GitHub
Browse files

Update base.py

parent 2a325003
No related branches found
No related tags found
No related merge requests found
...@@ -65,24 +65,24 @@ class BaseClassifier(BaseEstimator, ): ...@@ -65,24 +65,24 @@ class BaseClassifier(BaseEstimator, ):
else: else:
return self.__class__.__name__ + " with no config." return self.__class__.__name__ + " with no config."
def get_base_estimator(self, base_estimator, estimator_config): def get_base_estimator(self, estimator, estimator_config):
if estimator_config is None: if estimator_config is None:
estimator_config = {} estimator_config = {}
if base_estimator is None: if base_estimator is None:
return DecisionTreeClassifier(**estimator_config) return DecisionTreeClassifier(**estimator_config)
if isinstance(base_estimator, str): # pragma: no cover if isinstance(estimator, str): # pragma: no cover
if base_estimator == "DecisionTreeClassifier": if estimator == "DecisionTreeClassifier":
return DecisionTreeClassifier(**estimator_config) return DecisionTreeClassifier(**estimator_config)
elif base_estimator == "AdaboostClassifier": elif estimator == "AdaboostClassifier":
return AdaBoostClassifier(**estimator_config) return AdaBoostClassifier(**estimator_config)
elif base_estimator == "RandomForestClassifier": elif estimator == "RandomForestClassifier":
return RandomForestClassifier(**estimator_config) return RandomForestClassifier(**estimator_config)
else: else:
raise ValueError( raise ValueError(
'Base estimator string {} does not match an available classifier.'.format( 'Base estimator string {} does not match an available classifier.'.format(
base_estimator)) base_estimator))
elif isinstance(base_estimator, BaseEstimator): elif isinstance(estimator, BaseEstimator):
return base_estimator.set_params(**estimator_config) return estimator.set_params(**estimator_config)
else: else:
raise ValueError( raise ValueError(
'base_estimator must be either a string or a BaseEstimator child class, it is {}'.format( 'base_estimator must be either a string or a BaseEstimator child class, it is {}'.format(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment