diff --git a/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py b/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py index 191029796b4b38a951643351c60d69803a61edf7..80d7ee63c269cc58917e961872e42b6555c186c6 100644 --- a/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py +++ b/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py @@ -16,59 +16,35 @@ from .base import get_metric from .. import metrics -def search_best_settings(dataset_var, labels, classifier_module, - classifier_name, - metrics, learning_indices, i_k_folds, random_state, - directory, views_indices=None, nb_cores=1, - searching_tool="randomized_search-equiv", n_iter=1, - classifier_config=None): - """Used to select the right hyper-parameter optimization function - to optimize hyper parameters""" - if views_indices is None: - views_indices = list(range(dataset_var.get_nb_view)) - output_file_name = directory - thismodule = sys.modules[__name__] - if searching_tool is not "None": - searching_tool_method = getattr(thismodule, - searching_tool.split("-")[0]) - best_settings, scores, params = searching_tool_method( - dataset_var, labels, "multiview", random_state, output_file_name, - classifier_module, classifier_name, i_k_folds, - nb_cores, metrics, n_iter, classifier_config, - learning_indices=learning_indices, view_indices=views_indices, - equivalent_draws=searching_tool.endswith("equiv")) - gen_report(params, scores, directory, ) - else: - best_settings = classifier_config - return best_settings # or well set clasifier ? +# def search_best_settings(dataset_var, labels, classifier_module, +# classifier_name, +# metrics, learning_indices, i_k_folds, random_state, +# directory, views_indices=None, nb_cores=1, +# searching_tool="randomized_search-equiv", n_iter=1, +# classifier_config=None): +# """Used to select the right hyper-parameter optimization function +# to optimize hyper parameters""" +# if views_indices is None: +# views_indices = list(range(dataset_var.get_nb_view)) +# output_file_name = directory +# thismodule = sys.modules[__name__] +# if searching_tool is not "None": +# searching_tool_method = getattr(thismodule, +# searching_tool.split("-")[0]) +# best_settings, scores, params = searching_tool_method( +# dataset_var, labels, "multiview", random_state, output_file_name, +# classifier_module, classifier_name, i_k_folds, +# nb_cores, metrics, n_iter, classifier_config, +# learning_indices=learning_indices, view_indices=views_indices, +# equivalent_draws=searching_tool.endswith("equiv")) +# gen_report(params, scores, directory, ) +# else: +# best_settings = classifier_config +# return best_settings # or well set clasifier ? class HPSearch: - # def __init__(self, y, framework, random_state, output_file_name, - # classifier_module, - # classifier_name, folds=4, nb_cores=1, - # metric = [["accuracy_score", None]], - # classifier_kwargs={}, learning_indices=None, - # view_indices=None, - # track_tracebacks=True): - # estimator = getattr(classifier_module, classifier_name)( - # random_state=random_state, - # **classifier_kwargs) - # self.init_params() - # self.estimator = get_mc_estim(estimator, random_state, - # multiview=(framework == "multiview"), - # y=y) - # self.folds = folds - # self.nb_cores = nb_cores - # self.clasifier_kwargs = classifier_kwargs - # self.learning_indices = learning_indices - # self.view_indices = view_indices - # self.output_file_name = output_file_name - # metric_module, metric_kwargs = get_metric(metric) - # self.scorer = metric_module.get_scorer(**metric_kwargs) - # self.track_tracebacks = track_tracebacks - def get_scoring(self, metric): if isinstance(metric, list): metric_module, metric_kwargs = get_metric(metric) @@ -137,11 +113,6 @@ class HPSearch: self.n_splits_ = n_splits return self - @abstractmethod - def init_params(self): - self.params_dict = {} - raise NotImplementedError - @abstractmethod def get_candidate_params(self, X): raise NotImplementedError @@ -203,9 +174,6 @@ class Random(RandomizedSearchCV, HPSearch): return HPSearch.fit_multiview(self, X, y=y, groups=groups, **fit_params) - # def init_params(self,): - # self.params_dict = self.estimator.gen_distribs() - def get_candidate_params(self, X): if self.equivalent_draws: self.n_iter = self.n_iter * X.nb_view diff --git a/multiview_platform/tests/test_utils/test_hyper_parameter_search.py b/multiview_platform/tests/test_utils/test_hyper_parameter_search.py index 01a4230fe6c8b1e1b9ef0e0fd3bc9eca58de48a3..41287784af397b9db1246c513d257bf8c8716407 100644 --- a/multiview_platform/tests/test_utils/test_hyper_parameter_search.py +++ b/multiview_platform/tests/test_utils/test_hyper_parameter_search.py @@ -49,7 +49,7 @@ class FakeEstimMV(BaseEstimator): -class Test_MultiviewCompatibleRandomizedSearchCV(unittest.TestCase): +class Test_Random(unittest.TestCase): @classmethod def setUpClass(cls): @@ -142,6 +142,27 @@ class Test_MultiviewCompatibleRandomizedSearchCV(unittest.TestCase): self.assertEqual(RSCV.best_params_["param1"], "return exact") +class Test_Grid(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.parameter_grid = {"param1":[5,6], "param2":[7,8]} + cls.estimator = FakeEstim() + + def test_simple(self): + grid = hyper_parameter_search.Grid(self.estimator, + param_grid=self.parameter_grid) + + def test_get_candidate_params(self): + grid = hyper_parameter_search.Grid(self.estimator, + param_grid=self.parameter_grid) + grid.get_candidate_params(None) + self.assertEqual(grid.candidate_params, [{"param1": 5, "param2": 7}, + {"param1": 5, "param2": 8}, + {"param1": 6, "param2": 7}, + {"param1": 6, "param2": 8}]) + + # if __name__ == '__main__': # # unittest.main() # suite = unittest.TestLoader().loadTestsFromTestCase(Test_randomized_search)