From f9f5d88bfbd8e00517e5c93e4bbf3bd37aabecd6 Mon Sep 17 00:00:00 2001 From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr> Date: Thu, 27 Feb 2020 19:55:34 +0100 Subject: [PATCH] Bug in hps --- .../utils/hyper_parameter_search.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py b/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py index 654870bb..621ee280 100644 --- a/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py +++ b/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py @@ -255,7 +255,7 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV): self.cv_results_["mean_test_score"].append( cross_validation_score) results[candidate_param_idx] = cross_validation_score - if cross_validation_score >= min(results.values()): + if cross_validation_score >= max(results.values()): self.best_params_ = candidate_params[candidate_param_idx] self.best_score_ = cross_validation_score except: @@ -269,9 +269,6 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV): 'No fits were performed. All HP combination returned errors \n\n' + '\n'.join( tracebacks)) self.cv_results_["mean_test_score"] = np.array(self.cv_results_["mean_test_score"]) - # for key, value in self.cv_results_.items(): - # if key.startswith("param_"): - # self.cv_results_[key] = np.ma.array(data=value, mask=[False for _ in value]) if self.refit: self.best_estimator_ = clone(base_estimator).set_params( **self.best_params_) -- GitLab