Skip to content
Snippets Groups Projects
Commit 8c865a7e authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Added some interpretation on mumbo

parent 53ba4ccc
No related branches found
No related tags found
No related merge requests found
...@@ -69,6 +69,11 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier): ...@@ -69,6 +69,11 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
for view_index in view_indices] for view_index in view_indices]
numpy_X, view_limits = X.to_numpy_array(example_indices=train_indices, numpy_X, view_limits = X.to_numpy_array(example_indices=train_indices,
view_indices=view_indices) view_indices=view_indices)
self.view_shapes = [view_limits[ind]-view_limits[ind-1]
if ind > 0
else view_limits[ind]
for ind in range(len(self.used_views))]
return MumboClassifier.fit(self, numpy_X, y[train_indices], return MumboClassifier.fit(self, numpy_X, y[train_indices],
view_limits) view_limits)
...@@ -83,11 +88,22 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier): ...@@ -83,11 +88,22 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
def get_interpretation(self, directory, labels, multiclass=False): def get_interpretation(self, directory, labels, multiclass=False):
self.view_importances = np.zeros(len(self.used_views)) self.view_importances = np.zeros(len(self.used_views))
for best_view, estimator_weight in zip(self.best_views_, self.estimator_weights_): self.feature_importances_ = [np.zeros(view_shape)
for view_shape in self.view_shapes]
for best_view, estimator_weight, estimator in zip(self.best_views_, self.estimator_weights_, self.estimators_):
self.view_importances[best_view] += estimator_weight self.view_importances[best_view] += estimator_weight
if hasattr(estimator, "feature_importances_"):
self.feature_importances_[best_view] += estimator.feature_importances_
importances_sum = sum([np.sum(feature_importances)
for feature_importances
in self.feature_importances_])
self.feature_importances_ = [feature_importances/importances_sum
for feature_importances
in self.feature_importances_]
self.view_importances /= np.sum(self.view_importances) self.view_importances /= np.sum(self.view_importances)
np.savetxt(directory+"view_importances.csv", self.view_importances, np.savetxt(directory+"view_importances.csv", self.view_importances,
delimiter=',') delimiter=',')
sorted_view_indices = np.argsort(-self.view_importances) sorted_view_indices = np.argsort(-self.view_importances)
interpret_string = "Mumbo used {} iterations to converge.".format(self.best_views_.shape[0]) interpret_string = "Mumbo used {} iterations to converge.".format(self.best_views_.shape[0])
interpret_string+= "\n\nViews importance : \n" interpret_string+= "\n\nViews importance : \n"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment