From 74ad04f0615ab238680ed24ea08127c80ab785f7 Mon Sep 17 00:00:00 2001 From: Dominique Benielli <dominique.benielli@univ-amu.fr> Date: Thu, 27 Feb 2025 10:55:52 +0100 Subject: [PATCH] work on faillure in tests --- summit/multiview_platform/utils/base.py | 6 +++--- summit/multiview_platform/utils/dataset.py | 2 +- summit/multiview_platform/utils/hyper_parameter_search.py | 7 +++---- .../tests/test_mono_view/test_exec_classif_mono_view.py | 2 +- summit/tests/test_utils/test_hyper_parameter_search.py | 8 ++------ 5 files changed, 10 insertions(+), 15 deletions(-) diff --git a/summit/multiview_platform/utils/base.py b/summit/multiview_platform/utils/base.py index 88fe4a06..21c8e7e3 100644 --- a/summit/multiview_platform/utils/base.py +++ b/summit/multiview_platform/utils/base.py @@ -68,7 +68,7 @@ class BaseClassifier(BaseEstimator, ): def get_base_estimator(self, estimator, estimator_config): if estimator_config is None: estimator_config = {} - if base_estimator is None: + if estimator is None: return DecisionTreeClassifier(**estimator_config) if isinstance(estimator, str): # pragma: no cover if estimator == "DecisionTreeClassifier": @@ -80,13 +80,13 @@ class BaseClassifier(BaseEstimator, ): else: raise ValueError( 'Base estimator string {} does not match an available classifier.'.format( - base_estimator)) + estimator)) elif isinstance(estimator, BaseEstimator): return estimator.set_params(**estimator_config) else: raise ValueError( 'base_estimator must be either a string or a BaseEstimator child class, it is {}'.format( - type(base_estimator))) + type(estimator))) def to_str(self, param_name): """ diff --git a/summit/multiview_platform/utils/dataset.py b/summit/multiview_platform/utils/dataset.py index 2a33b34b..c474f8ad 100644 --- a/summit/multiview_platform/utils/dataset.py +++ b/summit/multiview_platform/utils/dataset.py @@ -503,7 +503,7 @@ class HDF5Dataset(Dataset): """ selected_labels = self.get_labels(sample_indices) if decode: - return [label_name.decode("utf-8") + return [label_name for label, label_name in enumerate(self.dataset["Labels"].attrs["names"]) if label in selected_labels] diff --git a/summit/multiview_platform/utils/hyper_parameter_search.py b/summit/multiview_platform/utils/hyper_parameter_search.py index 4caa3dc0..c5cf33f1 100644 --- a/summit/multiview_platform/utils/hyper_parameter_search.py +++ b/summit/multiview_platform/utils/hyper_parameter_search.py @@ -153,8 +153,8 @@ class HPSearch: class Random(RandomizedSearchCV, HPSearch): def __init__(self, estimator, param_distributions=None, n_iter=10, - refit=False, n_jobs=1, scoring=None, cv=None, - random_state=None, error_score=np.nan, view_indices=None, + refit=False, n_jobs=1, scoring=None, cv=None, learning_indices=None, + random_state=None, view_indices=None, framework="monoview", equivalent_draws=True, track_tracebacks=True): param_distributions = self.get_param_distribs(estimator, param_distributions) @@ -165,9 +165,8 @@ class Random(RandomizedSearchCV, HPSearch): param_distributions=param_distributions, refit=refit, n_jobs=n_jobs, scoring=scoring, cv=cv, random_state=random_state) - self.framework = framework - self.available_indices = error_score + self.available_indices = learning_indices self.view_indices = view_indices self.equivalent_draws = equivalent_draws self.track_tracebacks = track_tracebacks diff --git a/summit/tests/test_mono_view/test_exec_classif_mono_view.py b/summit/tests/test_mono_view/test_exec_classif_mono_view.py index 5a7b884c..4bd74941 100644 --- a/summit/tests/test_mono_view/test_exec_classif_mono_view.py +++ b/summit/tests/test_mono_view/test_exec_classif_mono_view.py @@ -180,7 +180,7 @@ class Test_exec_monoview(unittest.TestCase): feature_ids=[str(i) for i in range(test_dataset.get_v(0).shape[1])], **{"classifier_name": "decision_tree", "view_index": 0, - "decision_tree": {}}) + "decision_tree": {}}, ) rm_tmp() # class Test_getKWARGS(unittest.TestCase): diff --git a/summit/tests/test_utils/test_hyper_parameter_search.py b/summit/tests/test_utils/test_hyper_parameter_search.py index 82e2dda0..8a5c9ec3 100644 --- a/summit/tests/test_utils/test_hyper_parameter_search.py +++ b/summit/tests/test_utils/test_hyper_parameter_search.py @@ -66,7 +66,7 @@ class Test_Random(unittest.TestCase): cls.scoring = make_scorer(accuracy_score, ) cls.cv = StratifiedKFold(n_splits=n_splits, ) cls.random_state = np.random.RandomState(42) - cls.learning_indices = np.array([0, 1, 2, 3, 4, ]) + # cls.learning_indices = np.array([0, 1, 2, 3, 4, ]) cls.view_indices = None cls.framework = "monoview" cls.equivalent_draws = False @@ -78,7 +78,7 @@ class Test_Random(unittest.TestCase): self.estimator, self.param_distributions, n_iter=self.n_iter, refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, cv=self.cv, random_state=self.random_state, - learning_indices=self.learning_indices, view_indices=self.view_indices, + view_indices=self.view_indices, framework=self.framework, equivalent_draws=self.equivalent_draws ) @@ -89,7 +89,6 @@ class Test_Random(unittest.TestCase): refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, cv=self.cv, random_state=self.random_state, - learning_indices=self.learning_indices, view_indices=self.view_indices, framework=self.framework, equivalent_draws=self.equivalent_draws @@ -106,7 +105,6 @@ class Test_Random(unittest.TestCase): refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, cv=self.cv, random_state=self.random_state, - learning_indices=self.learning_indices, view_indices=self.view_indices, framework="multiview", equivalent_draws=self.equivalent_draws @@ -121,7 +119,6 @@ class Test_Random(unittest.TestCase): refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, cv=self.cv, random_state=self.random_state, - learning_indices=self.learning_indices, view_indices=self.view_indices, framework="multiview", equivalent_draws=True @@ -137,7 +134,6 @@ class Test_Random(unittest.TestCase): refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, cv=self.cv, random_state=self.random_state, - learning_indices=self.learning_indices, view_indices=self.view_indices, framework="multiview", equivalent_draws=False -- GitLab