Skip to content
Snippets Groups Projects
Commit 43456222 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

added tracebacks tracking in HPS

parent 0ca73201
No related branches found
No related tags found
No related merge requests found
...@@ -3,6 +3,7 @@ import sys ...@@ -3,6 +3,7 @@ import sys
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import traceback
from scipy.stats import randint, uniform from scipy.stats import randint, uniform
from sklearn.model_selection import RandomizedSearchCV from sklearn.model_selection import RandomizedSearchCV
...@@ -181,7 +182,8 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV): ...@@ -181,7 +182,8 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV):
elif self.framework == "multiview": elif self.framework == "multiview":
return self.fit_multiview(X, y=y, groups=groups,**fit_params) return self.fit_multiview(X, y=y, groups=groups,**fit_params)
def fit_multiview(self, X, y=None, groups=None, **fit_params): def fit_multiview(self, X, y=None, groups=None, track_tracebacks=True,
**fit_params):
n_splits = self.cv.get_n_splits(self.available_indices, y[self.available_indices]) n_splits = self.cv.get_n_splits(self.available_indices, y[self.available_indices])
folds = list(self.cv.split(self.available_indices, y[self.available_indices])) folds = list(self.cv.split(self.available_indices, y[self.available_indices]))
if self.equivalent_draws: if self.equivalent_draws:
...@@ -195,8 +197,11 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV): ...@@ -195,8 +197,11 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV):
results = {} results = {}
self.cv_results_ = dict(("param_"+param_name, []) for param_name in candidate_params[0].keys()) self.cv_results_ = dict(("param_"+param_name, []) for param_name in candidate_params[0].keys())
self.cv_results_["mean_test_score"] = [] self.cv_results_["mean_test_score"] = []
n_failed = 0
tracebacks = []
for candidate_param_idx, candidate_param in enumerate(candidate_params): for candidate_param_idx, candidate_param in enumerate(candidate_params):
test_scores = np.zeros(n_splits)+1000 test_scores = np.zeros(n_splits)+1000
try:
for fold_idx, (train_indices, test_indices) in enumerate(folds): for fold_idx, (train_indices, test_indices) in enumerate(folds):
current_estimator = clone(base_estimator) current_estimator = clone(base_estimator)
current_estimator.set_params(**candidate_param) current_estimator.set_params(**candidate_param)
...@@ -219,6 +224,17 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV): ...@@ -219,6 +224,17 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV):
if cross_validation_score <= min(results.values()): if cross_validation_score <= min(results.values()):
self.best_params_ = candidate_params[candidate_param_idx] self.best_params_ = candidate_params[candidate_param_idx]
self.best_score_ = cross_validation_score self.best_score_ = cross_validation_score
except:
if track_tracebacks:
n_failed += 1
tracebacks.append(traceback.format_exc())
else:
raise
if n_failed == self.n_iter:
raise ValueError(
'No fits were performed. All HP combination returned errors \n\n' + '\n'.join(
tracebacks))
if self.refit: if self.refit:
self.best_estimator_ = clone(base_estimator).set_params(**self.best_params_) self.best_estimator_ = clone(base_estimator).set_params(**self.best_params_)
self.best_estimator_.fit(X, y, **fit_params) self.best_estimator_.fit(X, y, **fit_params)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment