diff --git a/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py b/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py index 654870bb55659a6a0e94a107a55a8b9cc3eaaf37..621ee28050d0e075ff26cc18b83442993d5026cc 100644 --- a/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py +++ b/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py @@ -255,7 +255,7 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV): self.cv_results_["mean_test_score"].append( cross_validation_score) results[candidate_param_idx] = cross_validation_score - if cross_validation_score >= min(results.values()): + if cross_validation_score >= max(results.values()): self.best_params_ = candidate_params[candidate_param_idx] self.best_score_ = cross_validation_score except: @@ -269,9 +269,6 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV): 'No fits were performed. All HP combination returned errors \n\n' + '\n'.join( tracebacks)) self.cv_results_["mean_test_score"] = np.array(self.cv_results_["mean_test_score"]) - # for key, value in self.cv_results_.items(): - # if key.startswith("param_"): - # self.cv_results_[key] = np.ma.array(data=value, mask=[False for _ in value]) if self.refit: self.best_estimator_ = clone(base_estimator).set_params( **self.best_params_)