Skip to content
Snippets Groups Projects
Commit 794d8283 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Mumbo

parent 2b5afd73
No related branches found
No related tags found
No related merge requests found
......@@ -21,13 +21,13 @@ split: 0.49
nb_folds: 2
nb_class: 2
classes:
type: ["multiview", "monoview"]
type: ["multiview",]
algos_monoview: ["decision_tree" ]
algos_multiview: ["weighted_linear_early_fusion",]
stats_iter: 2
metrics: ["accuracy_score", "f1_score"]
metric_princ: "accuracy_score"
hps_type: "Grid"
hps_type: "None"
hps_args:
n_iter: 4
equivalent_draws: False
......@@ -40,10 +40,7 @@ hps_args:
view_weights: [null]
monoview_classifier: ["decision_tree"]
monoview_classifier__max_depth: [1,2]
mumbo:
base_estimator: ["decision_tree", "random_forest"]
base_estimator__max_depth: [4]
n_estimator: [50, 500, 1000]
......@@ -51,6 +48,14 @@ hps_args:
######################################
## The Monoview Classifier arguments #
######################################
mumbo:
base_estimator__criterion: 'gini'
base_estimator__max_depth: 4
base_estimator__random_state: None
base_estimator__splitter: 'best'
best_view_mode: 'edge'
n_estimators: 50
#
#random_forest:
# n_estimators: [25]
......
from sklearn.tree import DecisionTreeClassifier
from sklearn.base import BaseEstimator
import numpy as np
import os
from multimodal.boosting.mumbo import MumboClassifier
from ..multiview.multiview_utils import BaseMultiviewClassifier
from ..utils.hyper_parameter_search import CustomRandint
from ..utils.dataset import get_examples_views_indices
from ..utils.base import base_boosting_estimators
from ..utils.organization import secure_file_path
from .. import monoview_classifiers
classifier_class_name = "Mumbo"
......@@ -69,9 +72,7 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
for view_index in view_indices]
numpy_X, view_limits = X.to_numpy_array(example_indices=train_indices,
view_indices=view_indices)
self.view_shapes = [view_limits[ind]-view_limits[ind-1]
if ind > 0
else view_limits[ind]
self.view_shapes = [view_limits[ind+1]-view_limits[ind]
for ind in range(len(self.used_views)) ]
return MumboClassifier.fit(self, numpy_X, y[train_indices],
......@@ -86,7 +87,7 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
view_indices=view_indices)
return MumboClassifier.predict(self, numpy_X)
def get_interpretation(self, directory, labels, multiclass=False):
def get_interpretation(self, directory, base_file_name, labels, multiclass=False):
self.view_importances = np.zeros(len(self.used_views))
self.feature_importances_ = [np.zeros(view_shape)
for view_shape in self.view_shapes]
......@@ -100,8 +101,14 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
self.feature_importances_ = [feature_importances/importances_sum
for feature_importances
in self.feature_importances_]
for feature_importances, view_name in zip(self.feature_importances_, self.view_names):
secure_file_path(os.path.join(directory, "feature_importances",
base_file_name+view_name+"-feature_importances.csv"))
np.savetxt(os.path.join(directory, "feature_importances",
base_file_name+view_name+"-feature_importances.csv"),
feature_importances, delimiter=',')
self.view_importances /= np.sum(self.view_importances)
np.savetxt(directory+"view_importances.csv", self.view_importances,
np.savetxt(os.path.join(directory, base_file_name+"view_importances.csv"), self.view_importances,
delimiter=',')
sorted_view_indices = np.argsort(-self.view_importances)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment