Skip to content
Snippets Groups Projects
Commit 1fbadee1 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Added check view in mumbo and some intepretation

parent 186b1106
No related branches found
No related tags found
No related merge requests found
......@@ -21,11 +21,11 @@ split: 0.25
nb_folds: 5
nb_class:
classes:
type: ["multiview", "monoview"]
type: ["multiview",]
algos_monoview: ["cb_boost","decision_tree", "random_forest"]
algos_multiview: ["mumbo","mvml", "lp_norm_mkl", "mucombo"]
stats_iter: 10
algos_multiview: ["mumbo",]
stats_iter: 2
metrics: ["accuracy_score", "f1_score"]
metric_princ: "accuracy_score"
hps_type: "randomized_search-equiv"
hps_iter: 5
\ No newline at end of file
hps_iter: 2
\ No newline at end of file
......@@ -30,6 +30,9 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
train_indices, view_indices = get_examples_views_indices(X,
train_indices,
view_indices)
self.used_views = view_indices
self.view_names = [X.get_view_name(view_index)
for view_index in view_indices]
numpy_X, view_limits = X.to_numpy_array(example_indices=train_indices,
view_indices=view_indices)
return MumboClassifier.fit(self, numpy_X, y[train_indices],
......@@ -39,10 +42,24 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
example_indices, view_indices = get_examples_views_indices(X,
example_indices,
view_indices)
if not np.array_equiv(np.sort(view_indices,axis=0), np.sort(self.used_views,axis=0)):
raise ValueError("Fitted with {} views, asking a prediction on {}".format(self.used_views, view_indices))
numpy_X, view_limits = X.to_numpy_array(example_indices=example_indices,
view_indices=view_indices)
return MumboClassifier.predict(self, numpy_X)
def get_interpretation(self, directory, labels, multiclass=False):
intepret_string = "Mumbo used "+str(len(self.best_views_)) +" iterations to converge, selecting views : \n" + ", ".join(map(str, self.best_views_)) + "\n\n With estimator weights : \n"+ "\n".join(map(str,self.estimator_weights_/np.sum(self.estimator_weights_)))
return intepret_string
self.view_importances = np.zeros(len(self.used_views))
for best_view, estimator_weight in zip(self.best_views_, self.estimator_weights_):
self.view_importances[best_view] += estimator_weight
self.view_importances /= np.sum(self.view_importances)
sorted_view_indices = np.argsort(-self.view_importances)
interpret_string = "Mumbo used {} iterations to converge.".format(self.best_views_.shape[0])
interpret_string+= "\n\nViews importance : \n"
for view_index in sorted_view_indices:
interpret_string+="- View {}({}), importance {}\n".format(view_index,
self.view_names[view_index],
self.view_importances[view_index])
interpret_string +="\n The boosting process selected views : \n" + ", ".join(map(str, self.best_views_))
interpret_string+="\n\n With estimator weights : \n"+ "\n".join(map(str,self.estimator_weights_/np.sum(self.estimator_weights_)))
return interpret_string
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment