From 2feb36c2bcd48155bd26d62250b55ef0f83477c3 Mon Sep 17 00:00:00 2001 From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr> Date: Tue, 3 Mar 2020 14:44:47 +0100 Subject: [PATCH] HPS --- .../utils/hyper_parameter_search.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py b/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py index 266e7f84..8132b9ab 100644 --- a/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py +++ b/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py @@ -222,6 +222,23 @@ class Grid(GridSearchCV, HPSearch): self.n_iter = len(self.candidate_params) +# class ParameterSamplerGrid: +# +# def __init__(self, param_distributions, n_iter): +# from math import floor +# n_points_per_param = int(n_iter **(1/len(param_distributions))) +# selected_params = dict((param_name, []) +# for param_name in param_distributions.keys()) +# for param_name, distribution in param_distributions.items(): +# if isinstance(distribution, list): +# if len(distribution)<n_points_per_param: +# selected_params[param_name] = distribution +# else: +# index_step = floor(len(distribution)/n_points_per_param-2) +# selected_params[param_name] = distribution[0]+[distribution[index*index_step+1] +# for index +# in range(n_points_per_param)] + -- GitLab