Skip to content
Snippets Groups Projects
Unverified Commit 2bc5e81a authored by Dominique Benielli's avatar Dominique Benielli Committed by GitHub
Browse files

Update base.py

parent 2a325003
No related branches found
No related tags found
No related merge requests found
......@@ -65,24 +65,24 @@ class BaseClassifier(BaseEstimator, ):
else:
return self.__class__.__name__ + " with no config."
def get_base_estimator(self, base_estimator, estimator_config):
def get_base_estimator(self, estimator, estimator_config):
if estimator_config is None:
estimator_config = {}
if base_estimator is None:
return DecisionTreeClassifier(**estimator_config)
if isinstance(base_estimator, str): # pragma: no cover
if base_estimator == "DecisionTreeClassifier":
if isinstance(estimator, str): # pragma: no cover
if estimator == "DecisionTreeClassifier":
return DecisionTreeClassifier(**estimator_config)
elif base_estimator == "AdaboostClassifier":
elif estimator == "AdaboostClassifier":
return AdaBoostClassifier(**estimator_config)
elif base_estimator == "RandomForestClassifier":
elif estimator == "RandomForestClassifier":
return RandomForestClassifier(**estimator_config)
else:
raise ValueError(
'Base estimator string {} does not match an available classifier.'.format(
base_estimator))
elif isinstance(base_estimator, BaseEstimator):
return base_estimator.set_params(**estimator_config)
elif isinstance(estimator, BaseEstimator):
return estimator.set_params(**estimator_config)
else:
raise ValueError(
'base_estimator must be either a string or a BaseEstimator child class, it is {}'.format(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment