Skip to content
Snippets Groups Projects
Commit 794d8283 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Mumbo

parent 2b5afd73
Branches
No related tags found
No related merge requests found
...@@ -21,13 +21,13 @@ split: 0.49 ...@@ -21,13 +21,13 @@ split: 0.49
nb_folds: 2 nb_folds: 2
nb_class: 2 nb_class: 2
classes: classes:
type: ["multiview", "monoview"] type: ["multiview",]
algos_monoview: ["decision_tree" ] algos_monoview: ["decision_tree" ]
algos_multiview: ["weighted_linear_early_fusion",] algos_multiview: ["weighted_linear_early_fusion",]
stats_iter: 2 stats_iter: 2
metrics: ["accuracy_score", "f1_score"] metrics: ["accuracy_score", "f1_score"]
metric_princ: "accuracy_score" metric_princ: "accuracy_score"
hps_type: "Grid" hps_type: "None"
hps_args: hps_args:
n_iter: 4 n_iter: 4
equivalent_draws: False equivalent_draws: False
...@@ -40,10 +40,7 @@ hps_args: ...@@ -40,10 +40,7 @@ hps_args:
view_weights: [null] view_weights: [null]
monoview_classifier: ["decision_tree"] monoview_classifier: ["decision_tree"]
monoview_classifier__max_depth: [1,2] monoview_classifier__max_depth: [1,2]
mumbo:
base_estimator: ["decision_tree", "random_forest"]
base_estimator__max_depth: [4]
n_estimator: [50, 500, 1000]
...@@ -51,6 +48,14 @@ hps_args: ...@@ -51,6 +48,14 @@ hps_args:
###################################### ######################################
## The Monoview Classifier arguments # ## The Monoview Classifier arguments #
###################################### ######################################
mumbo:
base_estimator__criterion: 'gini'
base_estimator__max_depth: 4
base_estimator__random_state: None
base_estimator__splitter: 'best'
best_view_mode: 'edge'
n_estimators: 50
# #
#random_forest: #random_forest:
# n_estimators: [25] # n_estimators: [25]
......
from sklearn.tree import DecisionTreeClassifier from sklearn.tree import DecisionTreeClassifier
from sklearn.base import BaseEstimator from sklearn.base import BaseEstimator
import numpy as np import numpy as np
import os
from multimodal.boosting.mumbo import MumboClassifier from multimodal.boosting.mumbo import MumboClassifier
from ..multiview.multiview_utils import BaseMultiviewClassifier from ..multiview.multiview_utils import BaseMultiviewClassifier
from ..utils.hyper_parameter_search import CustomRandint from ..utils.hyper_parameter_search import CustomRandint
from ..utils.dataset import get_examples_views_indices from ..utils.dataset import get_examples_views_indices
from ..utils.base import base_boosting_estimators from ..utils.base import base_boosting_estimators
from ..utils.organization import secure_file_path
from .. import monoview_classifiers from .. import monoview_classifiers
classifier_class_name = "Mumbo" classifier_class_name = "Mumbo"
...@@ -69,9 +72,7 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier): ...@@ -69,9 +72,7 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
for view_index in view_indices] for view_index in view_indices]
numpy_X, view_limits = X.to_numpy_array(example_indices=train_indices, numpy_X, view_limits = X.to_numpy_array(example_indices=train_indices,
view_indices=view_indices) view_indices=view_indices)
self.view_shapes = [view_limits[ind]-view_limits[ind-1] self.view_shapes = [view_limits[ind+1]-view_limits[ind]
if ind > 0
else view_limits[ind]
for ind in range(len(self.used_views)) ] for ind in range(len(self.used_views)) ]
return MumboClassifier.fit(self, numpy_X, y[train_indices], return MumboClassifier.fit(self, numpy_X, y[train_indices],
...@@ -86,7 +87,7 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier): ...@@ -86,7 +87,7 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
view_indices=view_indices) view_indices=view_indices)
return MumboClassifier.predict(self, numpy_X) return MumboClassifier.predict(self, numpy_X)
def get_interpretation(self, directory, labels, multiclass=False): def get_interpretation(self, directory, base_file_name, labels, multiclass=False):
self.view_importances = np.zeros(len(self.used_views)) self.view_importances = np.zeros(len(self.used_views))
self.feature_importances_ = [np.zeros(view_shape) self.feature_importances_ = [np.zeros(view_shape)
for view_shape in self.view_shapes] for view_shape in self.view_shapes]
...@@ -100,8 +101,14 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier): ...@@ -100,8 +101,14 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
self.feature_importances_ = [feature_importances/importances_sum self.feature_importances_ = [feature_importances/importances_sum
for feature_importances for feature_importances
in self.feature_importances_] in self.feature_importances_]
for feature_importances, view_name in zip(self.feature_importances_, self.view_names):
secure_file_path(os.path.join(directory, "feature_importances",
base_file_name+view_name+"-feature_importances.csv"))
np.savetxt(os.path.join(directory, "feature_importances",
base_file_name+view_name+"-feature_importances.csv"),
feature_importances, delimiter=',')
self.view_importances /= np.sum(self.view_importances) self.view_importances /= np.sum(self.view_importances)
np.savetxt(directory+"view_importances.csv", self.view_importances, np.savetxt(os.path.join(directory, base_file_name+"view_importances.csv"), self.view_importances,
delimiter=',') delimiter=',')
sorted_view_indices = np.argsort(-self.view_importances) sorted_view_indices = np.argsort(-self.view_importances)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment