Skip to content
Snippets Groups Projects
Commit 86f936d8 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Tests grid search

parent 4eca64c9
No related branches found
No related tags found
No related merge requests found
......@@ -16,59 +16,35 @@ from .base import get_metric
from .. import metrics
def search_best_settings(dataset_var, labels, classifier_module,
                         classifier_name,
                         metrics, learning_indices, i_k_folds, random_state,
                         directory, views_indices=None, nb_cores=1,
                         searching_tool="randomized_search-equiv", n_iter=1,
                         classifier_config=None):
    """Select and run the hyper-parameter search named by *searching_tool*.

    The part of *searching_tool* before the first ``-`` is resolved to a
    function in this module (e.g. ``"randomized_search-equiv"`` dispatches to
    ``randomized_search``); a trailing ``-equiv`` suffix enables
    equivalent-draws mode. If *searching_tool* is the literal string
    ``"None"``, no search is run and *classifier_config* is returned as-is.

    Parameters
    ----------
    dataset_var : multiview dataset object used by the search function.
    labels : label array forwarded to the search function.
    classifier_module, classifier_name : module and class name of the
        classifier whose hyper-parameters are optimized.
    metrics : metrics specification forwarded to the search function.
    learning_indices : indices of the training examples.
    i_k_folds : cross-validation folds object.
    random_state : seed/state forwarded to the search function.
    directory : output directory; also used as the report file name prefix.
    views_indices : list of view indices, or None to use every view.
    nb_cores : number of CPU cores for the search.
    searching_tool : name of the search strategy, or ``"None"`` to skip.
    n_iter : number of search iterations.
    classifier_config : fallback/base configuration for the classifier.

    Returns
    -------
    The best settings found by the search, or *classifier_config* when the
    search is skipped.
    """
    if views_indices is None:
        # NOTE(review): get_nb_view is read without parentheses — if it is a
        # plain method (not a property) this passes a bound method to range()
        # and raises; confirm against the dataset class.
        views_indices = list(range(dataset_var.get_nb_view))
    output_file_name = directory
    thismodule = sys.modules[__name__]
    # Bug fix: the original used `is not "None"`, an identity comparison with
    # a string literal (implementation-dependent, SyntaxWarning on CPython
    # >= 3.8). String content must be compared with equality.
    if searching_tool != "None":
        searching_tool_method = getattr(thismodule,
                                        searching_tool.split("-")[0])
        best_settings, scores, params = searching_tool_method(
            dataset_var, labels, "multiview", random_state, output_file_name,
            classifier_module, classifier_name, i_k_folds,
            nb_cores, metrics, n_iter, classifier_config,
            learning_indices=learning_indices, view_indices=views_indices,
            equivalent_draws=searching_tool.endswith("equiv"))
        gen_report(params, scores, directory)
    else:
        best_settings = classifier_config
    return best_settings  # or a fully configured classifier?
# def search_best_settings(dataset_var, labels, classifier_module,
# classifier_name,
# metrics, learning_indices, i_k_folds, random_state,
# directory, views_indices=None, nb_cores=1,
# searching_tool="randomized_search-equiv", n_iter=1,
# classifier_config=None):
# """Used to select the right hyper-parameter optimization function
# to optimize hyper parameters"""
# if views_indices is None:
# views_indices = list(range(dataset_var.get_nb_view))
# output_file_name = directory
# thismodule = sys.modules[__name__]
# if searching_tool is not "None":
# searching_tool_method = getattr(thismodule,
# searching_tool.split("-")[0])
# best_settings, scores, params = searching_tool_method(
# dataset_var, labels, "multiview", random_state, output_file_name,
# classifier_module, classifier_name, i_k_folds,
# nb_cores, metrics, n_iter, classifier_config,
# learning_indices=learning_indices, view_indices=views_indices,
# equivalent_draws=searching_tool.endswith("equiv"))
# gen_report(params, scores, directory, )
# else:
# best_settings = classifier_config
# return best_settings # or well set clasifier ?
class HPSearch:
# def __init__(self, y, framework, random_state, output_file_name,
# classifier_module,
# classifier_name, folds=4, nb_cores=1,
# metric = [["accuracy_score", None]],
# classifier_kwargs={}, learning_indices=None,
# view_indices=None,
# track_tracebacks=True):
# estimator = getattr(classifier_module, classifier_name)(
# random_state=random_state,
# **classifier_kwargs)
# self.init_params()
# self.estimator = get_mc_estim(estimator, random_state,
# multiview=(framework == "multiview"),
# y=y)
# self.folds = folds
# self.nb_cores = nb_cores
# self.clasifier_kwargs = classifier_kwargs
# self.learning_indices = learning_indices
# self.view_indices = view_indices
# self.output_file_name = output_file_name
# metric_module, metric_kwargs = get_metric(metric)
# self.scorer = metric_module.get_scorer(**metric_kwargs)
# self.track_tracebacks = track_tracebacks
def get_scoring(self, metric):
if isinstance(metric, list):
metric_module, metric_kwargs = get_metric(metric)
......@@ -137,11 +113,6 @@ class HPSearch:
self.n_splits_ = n_splits
return self
@abstractmethod
def init_params(self):
    """Abstract hook: build ``self.params_dict``, the mapping of
    hyper-parameter names to candidate values/distributions used by the
    search.

    Concrete subclasses must override this.  NOTE(review): the base body
    assigns an empty dict before raising, so the assignment is never
    observable unless a subclass deliberately calls the base after
    catching — presumably only the ``raise`` matters; confirm.
    """
    self.params_dict = {}
    raise NotImplementedError
@abstractmethod
def get_candidate_params(self, X):
    """Abstract hook: compute the concrete candidate parameter sets to
    evaluate for input data *X*.

    Concrete subclasses must override this (see ``Random`` in this module,
    which also rescales ``n_iter`` from *X* when equivalent draws are
    enabled).
    """
    raise NotImplementedError
......@@ -203,9 +174,6 @@ class Random(RandomizedSearchCV, HPSearch):
return HPSearch.fit_multiview(self, X, y=y, groups=groups,
**fit_params)
# def init_params(self,):
# self.params_dict = self.estimator.gen_distribs()
def get_candidate_params(self, X):
if self.equivalent_draws:
self.n_iter = self.n_iter * X.nb_view
......
......@@ -49,7 +49,7 @@ class FakeEstimMV(BaseEstimator):
class Test_MultiviewCompatibleRandomizedSearchCV(unittest.TestCase):
class Test_Random(unittest.TestCase):
@classmethod
def setUpClass(cls):
......@@ -142,6 +142,27 @@ class Test_MultiviewCompatibleRandomizedSearchCV(unittest.TestCase):
self.assertEqual(RSCV.best_params_["param1"], "return exact")
class Test_Grid(unittest.TestCase):
    """Unit tests for ``hyper_parameter_search.Grid``."""

    @classmethod
    def setUpClass(cls):
        # Shared fixture: a 2x2 parameter grid and a fake estimator.
        cls.parameter_grid = {"param1": [5, 6], "param2": [7, 8]}
        cls.estimator = FakeEstim()

    def test_simple(self):
        # Constructing a Grid search from a param grid must not raise.
        hyper_parameter_search.Grid(self.estimator,
                                    param_grid=self.parameter_grid)

    def test_get_candidate_params(self):
        searcher = hyper_parameter_search.Grid(self.estimator,
                                               param_grid=self.parameter_grid)
        searcher.get_candidate_params(None)
        # Candidates enumerate the cartesian product of the grid, with
        # param1 varying slowest (sklearn ParameterGrid order).
        expected = [{"param1": 5, "param2": 7},
                    {"param1": 5, "param2": 8},
                    {"param1": 6, "param2": 7},
                    {"param1": 6, "param2": 8}]
        self.assertEqual(searcher.candidate_params, expected)
# if __name__ == '__main__':
# # unittest.main()
# suite = unittest.TestLoader().loadTestsFromTestCase(Test_randomized_search)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment