diff --git a/summit/multiview_platform/utils/base.py b/summit/multiview_platform/utils/base.py index 67df47a6eb850da32faa5ea132b1e73a0d746ace..88fe4a064ba9cf6ef03cdb85a27290e9fa457008 100644 --- a/summit/multiview_platform/utils/base.py +++ b/summit/multiview_platform/utils/base.py @@ -65,24 +65,24 @@ class BaseClassifier(BaseEstimator, ): else: return self.__class__.__name__ + " with no config." - def get_base_estimator(self, base_estimator, estimator_config): + def get_base_estimator(self, estimator, estimator_config): if estimator_config is None: estimator_config = {} if base_estimator is None: return DecisionTreeClassifier(**estimator_config) - if isinstance(base_estimator, str): # pragma: no cover - if base_estimator == "DecisionTreeClassifier": + if isinstance(estimator, str): # pragma: no cover + if estimator == "DecisionTreeClassifier": return DecisionTreeClassifier(**estimator_config) - elif base_estimator == "AdaboostClassifier": + elif estimator == "AdaboostClassifier": return AdaBoostClassifier(**estimator_config) - elif base_estimator == "RandomForestClassifier": + elif estimator == "RandomForestClassifier": return RandomForestClassifier(**estimator_config) else: raise ValueError( 'Base estimator string {} does not match an available classifier.'.format( base_estimator)) - elif isinstance(base_estimator, BaseEstimator): - return base_estimator.set_params(**estimator_config) + elif isinstance(estimator, BaseEstimator): + return estimator.set_params(**estimator_config) else: raise ValueError( 'base_estimator must be either a string or a BaseEstimator child class, it is {}'.format(