Skip to content
Snippets Groups Projects
Commit 74ad04f0 authored by Dominique Benielli's avatar Dominique Benielli
Browse files

work on faillure in tests

parent 3361b380
No related branches found
No related tags found
No related merge requests found
...@@ -68,7 +68,7 @@ class BaseClassifier(BaseEstimator, ): ...@@ -68,7 +68,7 @@ class BaseClassifier(BaseEstimator, ):
def get_base_estimator(self, estimator, estimator_config): def get_base_estimator(self, estimator, estimator_config):
if estimator_config is None: if estimator_config is None:
estimator_config = {} estimator_config = {}
if base_estimator is None: if estimator is None:
return DecisionTreeClassifier(**estimator_config) return DecisionTreeClassifier(**estimator_config)
if isinstance(estimator, str): # pragma: no cover if isinstance(estimator, str): # pragma: no cover
if estimator == "DecisionTreeClassifier": if estimator == "DecisionTreeClassifier":
...@@ -80,13 +80,13 @@ class BaseClassifier(BaseEstimator, ): ...@@ -80,13 +80,13 @@ class BaseClassifier(BaseEstimator, ):
else: else:
raise ValueError( raise ValueError(
'Base estimator string {} does not match an available classifier.'.format( 'Base estimator string {} does not match an available classifier.'.format(
base_estimator)) estimator))
elif isinstance(estimator, BaseEstimator): elif isinstance(estimator, BaseEstimator):
return estimator.set_params(**estimator_config) return estimator.set_params(**estimator_config)
else: else:
raise ValueError( raise ValueError(
'base_estimator must be either a string or a BaseEstimator child class, it is {}'.format( 'base_estimator must be either a string or a BaseEstimator child class, it is {}'.format(
type(base_estimator))) type(estimator)))
def to_str(self, param_name): def to_str(self, param_name):
""" """
......
...@@ -503,7 +503,7 @@ class HDF5Dataset(Dataset): ...@@ -503,7 +503,7 @@ class HDF5Dataset(Dataset):
""" """
selected_labels = self.get_labels(sample_indices) selected_labels = self.get_labels(sample_indices)
if decode: if decode:
return [label_name.decode("utf-8") return [label_name
for label, label_name in for label, label_name in
enumerate(self.dataset["Labels"].attrs["names"]) enumerate(self.dataset["Labels"].attrs["names"])
if label in selected_labels] if label in selected_labels]
......
...@@ -153,8 +153,8 @@ class HPSearch: ...@@ -153,8 +153,8 @@ class HPSearch:
class Random(RandomizedSearchCV, HPSearch): class Random(RandomizedSearchCV, HPSearch):
def __init__(self, estimator, param_distributions=None, n_iter=10, def __init__(self, estimator, param_distributions=None, n_iter=10,
refit=False, n_jobs=1, scoring=None, cv=None, refit=False, n_jobs=1, scoring=None, cv=None, learning_indices=None,
random_state=None, error_score=np.nan, view_indices=None, random_state=None, view_indices=None,
framework="monoview", framework="monoview",
equivalent_draws=True, track_tracebacks=True): equivalent_draws=True, track_tracebacks=True):
param_distributions = self.get_param_distribs(estimator, param_distributions) param_distributions = self.get_param_distribs(estimator, param_distributions)
...@@ -165,9 +165,8 @@ class Random(RandomizedSearchCV, HPSearch): ...@@ -165,9 +165,8 @@ class Random(RandomizedSearchCV, HPSearch):
param_distributions=param_distributions, param_distributions=param_distributions,
refit=refit, n_jobs=n_jobs, scoring=scoring, refit=refit, n_jobs=n_jobs, scoring=scoring,
cv=cv, random_state=random_state) cv=cv, random_state=random_state)
self.framework = framework self.framework = framework
self.available_indices = error_score self.available_indices = learning_indices
self.view_indices = view_indices self.view_indices = view_indices
self.equivalent_draws = equivalent_draws self.equivalent_draws = equivalent_draws
self.track_tracebacks = track_tracebacks self.track_tracebacks = track_tracebacks
......
...@@ -180,7 +180,7 @@ class Test_exec_monoview(unittest.TestCase): ...@@ -180,7 +180,7 @@ class Test_exec_monoview(unittest.TestCase):
feature_ids=[str(i) for i in range(test_dataset.get_v(0).shape[1])], feature_ids=[str(i) for i in range(test_dataset.get_v(0).shape[1])],
**{"classifier_name": "decision_tree", **{"classifier_name": "decision_tree",
"view_index": 0, "view_index": 0,
"decision_tree": {}}) "decision_tree": {}}, )
rm_tmp() rm_tmp()
# class Test_getKWARGS(unittest.TestCase): # class Test_getKWARGS(unittest.TestCase):
......
...@@ -66,7 +66,7 @@ class Test_Random(unittest.TestCase): ...@@ -66,7 +66,7 @@ class Test_Random(unittest.TestCase):
cls.scoring = make_scorer(accuracy_score, ) cls.scoring = make_scorer(accuracy_score, )
cls.cv = StratifiedKFold(n_splits=n_splits, ) cls.cv = StratifiedKFold(n_splits=n_splits, )
cls.random_state = np.random.RandomState(42) cls.random_state = np.random.RandomState(42)
cls.learning_indices = np.array([0, 1, 2, 3, 4, ]) # cls.learning_indices = np.array([0, 1, 2, 3, 4, ])
cls.view_indices = None cls.view_indices = None
cls.framework = "monoview" cls.framework = "monoview"
cls.equivalent_draws = False cls.equivalent_draws = False
...@@ -78,7 +78,7 @@ class Test_Random(unittest.TestCase): ...@@ -78,7 +78,7 @@ class Test_Random(unittest.TestCase):
self.estimator, self.param_distributions, n_iter=self.n_iter, self.estimator, self.param_distributions, n_iter=self.n_iter,
refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, cv=self.cv, refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, cv=self.cv,
random_state=self.random_state, random_state=self.random_state,
learning_indices=self.learning_indices, view_indices=self.view_indices, view_indices=self.view_indices,
framework=self.framework, framework=self.framework,
equivalent_draws=self.equivalent_draws equivalent_draws=self.equivalent_draws
) )
...@@ -89,7 +89,6 @@ class Test_Random(unittest.TestCase): ...@@ -89,7 +89,6 @@ class Test_Random(unittest.TestCase):
refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring,
cv=self.cv, cv=self.cv,
random_state=self.random_state, random_state=self.random_state,
learning_indices=self.learning_indices,
view_indices=self.view_indices, view_indices=self.view_indices,
framework=self.framework, framework=self.framework,
equivalent_draws=self.equivalent_draws equivalent_draws=self.equivalent_draws
...@@ -106,7 +105,6 @@ class Test_Random(unittest.TestCase): ...@@ -106,7 +105,6 @@ class Test_Random(unittest.TestCase):
refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring,
cv=self.cv, cv=self.cv,
random_state=self.random_state, random_state=self.random_state,
learning_indices=self.learning_indices,
view_indices=self.view_indices, view_indices=self.view_indices,
framework="multiview", framework="multiview",
equivalent_draws=self.equivalent_draws equivalent_draws=self.equivalent_draws
...@@ -121,7 +119,6 @@ class Test_Random(unittest.TestCase): ...@@ -121,7 +119,6 @@ class Test_Random(unittest.TestCase):
refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring,
cv=self.cv, cv=self.cv,
random_state=self.random_state, random_state=self.random_state,
learning_indices=self.learning_indices,
view_indices=self.view_indices, view_indices=self.view_indices,
framework="multiview", framework="multiview",
equivalent_draws=True equivalent_draws=True
...@@ -137,7 +134,6 @@ class Test_Random(unittest.TestCase): ...@@ -137,7 +134,6 @@ class Test_Random(unittest.TestCase):
refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring,
cv=self.cv, cv=self.cv,
random_state=self.random_state, random_state=self.random_state,
learning_indices=self.learning_indices,
view_indices=self.view_indices, view_indices=self.view_indices,
framework="multiview", framework="multiview",
equivalent_draws=False equivalent_draws=False
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment