diff --git a/summit/multiview_platform/utils/base.py b/summit/multiview_platform/utils/base.py index 88fe4a064ba9cf6ef03cdb85a27290e9fa457008..21c8e7e3163b0908936daaca3af5b8f51b5e56c5 100644 --- a/summit/multiview_platform/utils/base.py +++ b/summit/multiview_platform/utils/base.py @@ -68,7 +68,7 @@ class BaseClassifier(BaseEstimator, ): def get_base_estimator(self, estimator, estimator_config): if estimator_config is None: estimator_config = {} - if base_estimator is None: + if estimator is None: return DecisionTreeClassifier(**estimator_config) if isinstance(estimator, str): # pragma: no cover if estimator == "DecisionTreeClassifier": @@ -80,13 +80,13 @@ class BaseClassifier(BaseEstimator, ): else: raise ValueError( 'Base estimator string {} does not match an available classifier.'.format( - base_estimator)) + estimator)) elif isinstance(estimator, BaseEstimator): return estimator.set_params(**estimator_config) else: raise ValueError( 'base_estimator must be either a string or a BaseEstimator child class, it is {}'.format( - type(base_estimator))) + type(estimator))) def to_str(self, param_name): """ diff --git a/summit/multiview_platform/utils/dataset.py b/summit/multiview_platform/utils/dataset.py index 2a33b34bac58766b39049efa68614ddde26b9522..c474f8ad6c942cae74c4dd1eeae813ac641e7cc8 100644 --- a/summit/multiview_platform/utils/dataset.py +++ b/summit/multiview_platform/utils/dataset.py @@ -503,7 +503,7 @@ class HDF5Dataset(Dataset): """ selected_labels = self.get_labels(sample_indices) if decode: - return [label_name.decode("utf-8") + return [label_name for label, label_name in enumerate(self.dataset["Labels"].attrs["names"]) if label in selected_labels] diff --git a/summit/multiview_platform/utils/hyper_parameter_search.py b/summit/multiview_platform/utils/hyper_parameter_search.py index 4caa3dc034c3d5a1a34d9ce901e44b09438cbb55..c5cf33f1e12fd9f9250c2dc52b828d61024ff561 100644 --- a/summit/multiview_platform/utils/hyper_parameter_search.py +++ b/summit/multiview_platform/utils/hyper_parameter_search.py @@ -153,8 +153,8 @@ class HPSearch: class Random(RandomizedSearchCV, HPSearch): def __init__(self, estimator, param_distributions=None, n_iter=10, - refit=False, n_jobs=1, scoring=None, cv=None, - random_state=None, error_score=np.nan, view_indices=None, + refit=False, n_jobs=1, scoring=None, cv=None, learning_indices=None, + random_state=None, view_indices=None, framework="monoview", equivalent_draws=True, track_tracebacks=True): param_distributions = self.get_param_distribs(estimator, param_distributions) @@ -165,9 +165,8 @@ class Random(RandomizedSearchCV, HPSearch): param_distributions=param_distributions, refit=refit, n_jobs=n_jobs, scoring=scoring, cv=cv, random_state=random_state) - self.framework = framework - self.available_indices = error_score + self.available_indices = learning_indices self.view_indices = view_indices self.equivalent_draws = equivalent_draws self.track_tracebacks = track_tracebacks diff --git a/summit/tests/test_mono_view/test_exec_classif_mono_view.py b/summit/tests/test_mono_view/test_exec_classif_mono_view.py index 5a7b884c9bb436cbb3eef8727770f59b7829d2fb..4bd749412da438000945093e41279e8b1818d2ff 100644 --- a/summit/tests/test_mono_view/test_exec_classif_mono_view.py +++ b/summit/tests/test_mono_view/test_exec_classif_mono_view.py @@ -180,7 +180,7 @@ class Test_exec_monoview(unittest.TestCase): feature_ids=[str(i) for i in range(test_dataset.get_v(0).shape[1])], **{"classifier_name": "decision_tree", "view_index": 0, - "decision_tree": {}}) + "decision_tree": {}}, ) rm_tmp() # class Test_getKWARGS(unittest.TestCase): diff --git a/summit/tests/test_utils/test_hyper_parameter_search.py b/summit/tests/test_utils/test_hyper_parameter_search.py index 82e2dda0749b83f93b7822ca7a07131427b82d0f..8a5c9ec3d6fdca8bf6e50e5a408ae7ebf484a52e 100644 --- a/summit/tests/test_utils/test_hyper_parameter_search.py +++ b/summit/tests/test_utils/test_hyper_parameter_search.py @@ -66,7 +66,7 @@ class Test_Random(unittest.TestCase): cls.scoring = make_scorer(accuracy_score, ) cls.cv = StratifiedKFold(n_splits=n_splits, ) cls.random_state = np.random.RandomState(42) - cls.learning_indices = np.array([0, 1, 2, 3, 4, ]) + # cls.learning_indices = np.array([0, 1, 2, 3, 4, ]) cls.view_indices = None cls.framework = "monoview" cls.equivalent_draws = False @@ -78,7 +78,7 @@ class Test_Random(unittest.TestCase): self.estimator, self.param_distributions, n_iter=self.n_iter, refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, cv=self.cv, random_state=self.random_state, - learning_indices=self.learning_indices, view_indices=self.view_indices, + view_indices=self.view_indices, framework=self.framework, equivalent_draws=self.equivalent_draws ) @@ -89,7 +89,6 @@ class Test_Random(unittest.TestCase): refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, cv=self.cv, random_state=self.random_state, - learning_indices=self.learning_indices, view_indices=self.view_indices, framework=self.framework, equivalent_draws=self.equivalent_draws @@ -106,7 +105,6 @@ class Test_Random(unittest.TestCase): refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, cv=self.cv, random_state=self.random_state, - learning_indices=self.learning_indices, view_indices=self.view_indices, framework="multiview", equivalent_draws=self.equivalent_draws @@ -121,7 +119,6 @@ class Test_Random(unittest.TestCase): refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, cv=self.cv, random_state=self.random_state, - learning_indices=self.learning_indices, view_indices=self.view_indices, framework="multiview", equivalent_draws=True @@ -137,7 +134,6 @@ class Test_Random(unittest.TestCase): refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, cv=self.cv, random_state=self.random_state, - learning_indices=self.learning_indices, view_indices=self.view_indices, framework="multiview", equivalent_draws=False