diff --git a/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py b/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py index 266e7f84690f47e3742e8d69bdfd03bfc2072282..8132b9ab4525eb58fe335d381277dd1fe0139d68 100644 --- a/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py +++ b/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py @@ -222,6 +222,23 @@ class Grid(GridSearchCV, HPSearch): self.n_iter = len(self.candidate_params) +# class ParameterSamplerGrid: +# +# def __init__(self, param_distributions, n_iter): +# from math import floor +# n_points_per_param = int(n_iter **(1/len(param_distributions))) +# selected_params = dict((param_name, []) +# for param_name in param_distributions.keys()) +# for param_name, distribution in param_distributions.items(): +# if isinstance(distribution, list): +# if len(distribution)<n_points_per_param: +# selected_params[param_name] = distribution +# else: +# index_step = floor(len(distribution)/n_points_per_param-2) +# selected_params[param_name] = distribution[0]+[distribution[index*index_step+1] +# for index +# in range(n_points_per_param)] +