diff --git a/multiview_platform/mono_multi_view_classifiers/multiview/multiview_utils.py b/multiview_platform/mono_multi_view_classifiers/multiview/multiview_utils.py index 3d78d991bd40565fbf4da3aec1b81b547d5c7bb7..31e2546b6a93a52ee7d7d4a56ff5ea8ca8128784 100644 --- a/multiview_platform/mono_multi_view_classifiers/multiview/multiview_utils.py +++ b/multiview_platform/mono_multi_view_classifiers/multiview/multiview_utils.py @@ -4,7 +4,7 @@ import numpy as np from .. import monoview_classifiers from ..utils.base import BaseClassifier, ResultAnalyser -from ..utils.dataset import RAMDataset +from ..utils.dataset import RAMDataset, get_examples_views_indices class FakeEstimator(): @@ -29,6 +29,7 @@ class BaseMultiviewClassifier(BaseClassifier): self.random_state = random_state self.short_name = self.__module__.split(".")[-1] self.weird_strings = {} + self.used_views = None @abstractmethod def fit(self, X, y, train_indices=None, view_indices=None): @@ -38,6 +39,10 @@ class BaseMultiviewClassifier(BaseClassifier): def predict(self, X, example_indices=None, view_indices=None): pass + def _check_views(self, view_indices): + if self.used_views is not None and not np.array_equal(np.sort(self.used_views), np.sort(view_indices)): + raise ValueError('Used {} views to fit, and trying to predict on {}'.format(self.used_views, view_indices)) + def to_str(self, param_name): if param_name in self.weird_strings: string = "" diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/diversity_utils.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/diversity_utils.py index 05e4cd05af16e4202eb5f8db07af2891510ca900..a49845191d950fa26026d7d5945ba5853275f199 100644 --- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/diversity_utils.py +++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/diversity_utils.py @@ -30,6 +30,7 @@ class DiversityFusionClassifier(BaseMultiviewClassifier, train_indices, view_indices = get_examples_views_indices(X, train_indices, view_indices) + self.used_views = view_indices # TODO : Finer analysis, may support a bit of mutliclass if np.unique(y[train_indices]).shape[0] > 2: raise ValueError( @@ -56,6 +57,7 @@ class DiversityFusionClassifier(BaseMultiviewClassifier, example_indices, view_indices = get_examples_views_indices(X, example_indices, view_indices) + self._check_views(view_indices) nb_class = X.get_nb_class() if nb_class > 2: nb_class = 3 diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_fusion_utils.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_fusion_utils.py index f657c6c2d4dd65e6453b28a23870e32923ad0b85..e9cbac4c770a826183d713d691f7bcee25225cbe 100644 --- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_fusion_utils.py +++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_fusion_utils.py @@ -27,6 +27,7 @@ class BaseJumboFusion(LateFusionClassifier): example_indices, view_indices = get_examples_views_indices(X, example_indices, view_indices) + self._check_views(view_indices) monoview_decisions = self.predict_monoview(X, example_indices=example_indices, view_indices=view_indices) @@ -36,6 +37,7 @@ class BaseJumboFusion(LateFusionClassifier): train_indices, view_indices = get_examples_views_indices(X, train_indices, view_indices) + self.used_views = view_indices self.init_classifiers(len(view_indices), nb_monoview_per_view=self.nb_monoview_per_view) self.fit_monoview_estimators(X, y, train_indices=train_indices, diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/late_fusion_utils.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/late_fusion_utils.py index e2e8da5db99dd6a63fa35bd360c7b6fdff951c6a..d6ff8b4c8a7f585707959ad055accbe074afc8d8 100644 --- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/late_fusion_utils.py +++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/late_fusion_utils.py @@ -97,6 +97,7 @@ class LateFusionClassifier(BaseMultiviewClassifier, BaseFusionClassifier): train_indices, view_indices = get_examples_views_indices(X, train_indices, view_indices) + self.used_views = view_indices if np.unique(y).shape[0] > 2: multiclass = True else: diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/utils.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/utils.py index 6aa2a7bedbf7d21501983827d2463fb14f4a527b..5fbd4d56aeb6ae4b5bec4f6c8be8e25f24473c44 100644 --- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/utils.py +++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/utils.py @@ -6,59 +6,59 @@ def get_names(classed_list): return np.array([object_.__class__.__name__ for object_ in classed_list]) -class BaseMultiviewClassifier(BaseEstimator, ClassifierMixin): +# class BaseMultiviewClassifier(BaseEstimator, ClassifierMixin): +# +# def __init__(self, random_state): +# self.random_state = random_state +# +# def genBestParams(self, detector): +# return dict((param_name, detector.best_params_[param_name]) +# for param_name in self.param_names) +# +# def genParamsFromDetector(self, detector): +# if self.classed_params: +# classed_dict = dict((classed_param, get_names( +# detector.cv_results_["param_" + classed_param])) +# for classed_param in self.classed_params) +# if self.param_names: +# return [(param_name, +# np.array(detector.cv_results_["param_" + param_name])) +# if param_name not in self.classed_params else ( +# param_name, classed_dict[param_name]) +# for param_name in self.param_names] +# else: +# return [()] +# +# def genDistribs(self): +# return dict((param_name, distrib) for param_name, distrib in +# zip(self.param_names, self.distribs)) +# +# def getConfig(self): +# if self.param_names: +# return "\n\t\t- " + self.__class__.__name__ + "with " + ", ".join( +# [param_name + " : " + self.to_str(param_name) for param_name in +# self.param_names]) +# else: +# return "\n\t\t- " + self.__class__.__name__ + "with no config." +# +# def to_str(self, param_name): +# if param_name in self.weird_strings: +# if self.weird_strings[param_name] == "class_name": +# return self.get_params()[param_name].__class__.__name__ +# else: +# return self.weird_strings[param_name]( +# self.get_params()[param_name]) +# else: +# return str(self.get_params()[param_name]) +# +# def get_interpretation(self): +# return "No detailed interpretation function" - def __init__(self, random_state): - self.random_state = random_state - - def genBestParams(self, detector): - return dict((param_name, detector.best_params_[param_name]) - for param_name in self.param_names) - - def genParamsFromDetector(self, detector): - if self.classed_params: - classed_dict = dict((classed_param, get_names( - detector.cv_results_["param_" + classed_param])) - for classed_param in self.classed_params) - if self.param_names: - return [(param_name, - np.array(detector.cv_results_["param_" + param_name])) - if param_name not in self.classed_params else ( - param_name, classed_dict[param_name]) - for param_name in self.param_names] - else: - return [()] - - def genDistribs(self): - return dict((param_name, distrib) for param_name, distrib in - zip(self.param_names, self.distribs)) - - def getConfig(self): - if self.param_names: - return "\n\t\t- " + self.__class__.__name__ + "with " + ", ".join( - [param_name + " : " + self.to_str(param_name) for param_name in - self.param_names]) - else: - return "\n\t\t- " + self.__class__.__name__ + "with no config." - - def to_str(self, param_name): - if param_name in self.weird_strings: - if self.weird_strings[param_name] == "class_name": - return self.get_params()[param_name].__class__.__name__ - else: - return self.weird_strings[param_name]( - self.get_params()[param_name]) - else: - return str(self.get_params()[param_name]) - - def get_interpretation(self): - return "No detailed interpretation function" - - -def get_train_views_indices(dataset, train_indices, view_indices, ): - """This function is used to get all the examples indices and view indices if needed""" - if view_indices is None: - view_indices = np.arange(dataset.nb_view) - if train_indices is None: - train_indices = range(dataset.get_nb_examples()) - return train_indices, view_indices +# +# def get_train_views_indices(dataset, train_indices, view_indices, ): +# """This function is used to get all the examples indices and view indices if needed""" +# if view_indices is None: +# view_indices = np.arange(dataset.nb_view) +# if train_indices is None: +# train_indices = range(dataset.get_nb_examples()) +# return train_indices, view_indices diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/bayesian_inference_fusion.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/bayesian_inference_fusion.py index 5c5ae1c2d29d9a696c56baa5fef0a713aaeecdbc..b1cd5f9e6ea962cbffdbf5fa98bfea6e092ce9c0 100644 --- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/bayesian_inference_fusion.py +++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/bayesian_inference_fusion.py @@ -23,7 +23,7 @@ class BayesianInferenceClassifier(LateFusionClassifier): example_indices, view_indices = get_examples_views_indices(X, example_indices, view_indices) - + self._check_views(view_indices) if sum(self.weights) != 1.0: self.weights = self.weights / sum(self.weights) diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/majority_voting_fusion.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/majority_voting_fusion.py index 23c102b655297d0a68f8aed6309da6eda51206c0..53a255c764f79c8e68271caba38539dea019c774 100644 --- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/majority_voting_fusion.py +++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/majority_voting_fusion.py @@ -23,16 +23,16 @@ class MajorityVoting(LateFusionClassifier): rs=rs) def predict(self, X, example_indices=None, view_indices=None): - examples_indices, views_indices = get_examples_views_indices(X, + examples_indices, view_indices = get_examples_views_indices(X, example_indices, view_indices) - + self._check_views(view_indices) n_examples = len(examples_indices) votes = np.zeros((n_examples, X.get_nb_class(example_indices)), dtype=float) monoview_decisions = np.zeros((len(examples_indices), X.nb_view), dtype=int) - for index, view_index in enumerate(views_indices): + for index, view_index in enumerate(view_indices): monoview_decisions[:, index] = self.monoview_estimators[ index].predict( X.get_v(view_index, examples_indices)) diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py index 83b4c555721fc647458f1e36cbe478284b80e2c4..eaa8ce34e2ffd6ab4e811638d94d1211b5c94faf 100644 --- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py +++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py @@ -69,6 +69,7 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier): def fit(self, X, y, train_indices=None, view_indices=None): train_indices, X = self.transform_data_to_monoview(X, train_indices, view_indices) + self.used_views = view_indices if np.unique(y[train_indices]).shape[0] > 2 and \ not (isinstance(self.monoview_classifier, MultiClassWrapper)): self.monoview_classifier = get_mc_estim(self.monoview_classifier, @@ -81,6 +82,7 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier): def predict(self, X, example_indices=None, view_indices=None): _, X = self.transform_data_to_monoview(X, example_indices, view_indices) + self._check_views(self.view_indices) predicted_labels = self.monoview_classifier.predict(X) return predicted_labels diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_late_fusion.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_late_fusion.py index 32f4a71033d8d0c2e82804a45fc7c622e5b51598..403791ceec03ef3c18e9152a996bc5a39d41bd54 100644 --- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_late_fusion.py +++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_late_fusion.py @@ -17,11 +17,12 @@ class WeightedLinearLateFusion(LateFusionClassifier): nb_cores=nb_cores, weights=weights, rs=rs) def predict(self, X, example_indices=None, view_indices=None): - example_indices, views_indices = get_examples_views_indices(X, + example_indices, view_indices = get_examples_views_indices(X, example_indices, view_indices) + self._check_views(view_indices) view_scores = [] - for index, viewIndex in enumerate(views_indices): + for index, viewIndex in enumerate(view_indices): view_scores.append( np.array(self.monoview_estimators[index].predict_proba( X.get_v(viewIndex, example_indices))) * self.weights[index])