Skip to content
Snippets Groups Projects
Commit ceb55dac authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Tests passing

parent f8ca7c7c
Branches
Tags
No related merge requests found
Pipeline #3920 passed
......@@ -7,11 +7,11 @@ classifier_class_name = "WeightedLinearLateFusion"
class WeightedLinearLateFusion(LateFusionClassifier):
def __init__(self, random_state, classifier_names=None,
def __init__(self, random_state, classifiers_names=None,
classifier_configs=None, weights=None, nb_cores=1):
self.need_probas=True
super(WeightedLinearLateFusion, self).__init__(random_state=random_state,
classifier_names=classifier_names,
classifiers_names=classifiers_names,
classifier_configs=classifier_configs,
nb_cores=nb_cores,weights=weights)
......
......@@ -183,7 +183,11 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV):
folds = list(self.cv.split(self.available_indices, y[self.available_indices]))
if self.equivalent_draws:
self.n_iter = self.n_iter*X.nb_view
candidate_params = list(self._get_param_iterator())
# Fix to allow sklearn > 0.19
from sklearn.model_selection import ParameterSampler
candidate_params = list(
ParameterSampler(self.param_distributions, self.n_iter,
random_state=self.random_state))
base_estimator = clone(self.estimator)
results = {}
self.cv_results_ = dict(("param_"+param_name, []) for param_name in candidate_params[0].keys())
......
......@@ -210,7 +210,7 @@ class Test_format_previous_results(unittest.TestCase):
biclass_results["01"]["example_errors"][1]["mv"] = mv_error_data_2
# Running the function
metric_analysis, error_analysis, feature_importances, feature_stds = result_analysis.format_previous_results(biclass_results)
metric_analysis, error_analysis, feature_importances, feature_stds,labels = result_analysis.format_previous_results(biclass_results)
mean_df = pd.DataFrame(data=np.mean(np.array([metrics_1_data,
metrics_2_data]),
axis=0),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment