Skip to content
Snippets Groups Projects
Commit f9f5d88b authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Bug in hps

parent 00cecb45
Branches
Tags
No related merge requests found
...@@ -255,7 +255,7 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV): ...@@ -255,7 +255,7 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV):
self.cv_results_["mean_test_score"].append( self.cv_results_["mean_test_score"].append(
cross_validation_score) cross_validation_score)
results[candidate_param_idx] = cross_validation_score results[candidate_param_idx] = cross_validation_score
if cross_validation_score >= min(results.values()): if cross_validation_score >= max(results.values()):
self.best_params_ = candidate_params[candidate_param_idx] self.best_params_ = candidate_params[candidate_param_idx]
self.best_score_ = cross_validation_score self.best_score_ = cross_validation_score
except: except:
...@@ -269,9 +269,6 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV): ...@@ -269,9 +269,6 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV):
'No fits were performed. All HP combination returned errors \n\n' + '\n'.join( 'No fits were performed. All HP combination returned errors \n\n' + '\n'.join(
tracebacks)) tracebacks))
self.cv_results_["mean_test_score"] = np.array(self.cv_results_["mean_test_score"]) self.cv_results_["mean_test_score"] = np.array(self.cv_results_["mean_test_score"])
# for key, value in self.cv_results_.items():
# if key.startswith("param_"):
# self.cv_results_[key] = np.ma.array(data=value, mask=[False for _ in value])
if self.refit: if self.refit:
self.best_estimator_ = clone(base_estimator).set_params( self.best_estimator_ = clone(base_estimator).set_params(
**self.best_params_) **self.best_params_)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment