Skip to content
Snippets Groups Projects
Commit 43456222 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

added tracebacks tracking in HPS

parent 0ca73201
Branches
Tags
No related merge requests found
......@@ -3,6 +3,7 @@ import sys
import matplotlib.pyplot as plt
import numpy as np
import traceback
from scipy.stats import randint, uniform
from sklearn.model_selection import RandomizedSearchCV
......@@ -181,7 +182,8 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV):
elif self.framework == "multiview":
return self.fit_multiview(X, y=y, groups=groups,**fit_params)
def fit_multiview(self, X, y=None, groups=None, **fit_params):
def fit_multiview(self, X, y=None, groups=None, track_tracebacks=True,
**fit_params):
n_splits = self.cv.get_n_splits(self.available_indices, y[self.available_indices])
folds = list(self.cv.split(self.available_indices, y[self.available_indices]))
if self.equivalent_draws:
......@@ -195,8 +197,11 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV):
results = {}
self.cv_results_ = dict(("param_"+param_name, []) for param_name in candidate_params[0].keys())
self.cv_results_["mean_test_score"] = []
n_failed = 0
tracebacks = []
for candidate_param_idx, candidate_param in enumerate(candidate_params):
test_scores = np.zeros(n_splits)+1000
try:
for fold_idx, (train_indices, test_indices) in enumerate(folds):
current_estimator = clone(base_estimator)
current_estimator.set_params(**candidate_param)
......@@ -219,6 +224,17 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV):
if cross_validation_score <= min(results.values()):
self.best_params_ = candidate_params[candidate_param_idx]
self.best_score_ = cross_validation_score
except:
if track_tracebacks:
n_failed += 1
tracebacks.append(traceback.format_exc())
else:
raise
if n_failed == self.n_iter:
raise ValueError(
'No fits were performed. All HP combination returned errors \n\n' + '\n'.join(
tracebacks))
if self.refit:
self.best_estimator_ = clone(base_estimator).set_params(**self.best_params_)
self.best_estimator_.fit(X, y, **fit_params)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment