From 65b27dd5e9df2d93554006136e12fd5f3de31f23 Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Thu, 27 Feb 2020 20:10:36 +0100
Subject: [PATCH] Added check view in develop

---
 .../multiview/multiview_utils.py              |   7 +-
 .../additions/diversity_utils.py              |   2 +
 .../additions/jumbo_fusion_utils.py           |   2 +
 .../additions/late_fusion_utils.py            |   1 +
 .../multiview_classifiers/additions/utils.py  | 110 +++++++++---------
 .../bayesian_inference_fusion.py              |   2 +-
 .../majority_voting_fusion.py                 |   6 +-
 .../weighted_linear_early_fusion.py           |   2 +
 .../weighted_linear_late_fusion.py            |   5 +-
 9 files changed, 75 insertions(+), 62 deletions(-)

diff --git a/multiview_platform/mono_multi_view_classifiers/multiview/multiview_utils.py b/multiview_platform/mono_multi_view_classifiers/multiview/multiview_utils.py
index 3d78d991..31e2546b 100644
--- a/multiview_platform/mono_multi_view_classifiers/multiview/multiview_utils.py
+++ b/multiview_platform/mono_multi_view_classifiers/multiview/multiview_utils.py
@@ -4,7 +4,7 @@ import numpy as np
 
 from .. import monoview_classifiers
 from ..utils.base import BaseClassifier, ResultAnalyser
-from ..utils.dataset import RAMDataset
+from ..utils.dataset import RAMDataset, get_examples_views_indices
 
 
 class FakeEstimator():
@@ -29,6 +29,7 @@ class BaseMultiviewClassifier(BaseClassifier):
         self.random_state = random_state
         self.short_name = self.__module__.split(".")[-1]
         self.weird_strings = {}
+        self.used_views = None
 
     @abstractmethod
     def fit(self, X, y, train_indices=None, view_indices=None):
@@ -38,6 +39,10 @@ class BaseMultiviewClassifier(BaseClassifier):
     def predict(self, X, example_indices=None, view_indices=None):
         pass
 
+    def _check_views(self, view_indices):
+        if self.used_views is not None and not np.array_equal(np.sort(self.used_views), np.sort(view_indices)):
+            raise ValueError('Used {} views to fit, and trying to predict on {}'.format(self.used_views, view_indices))
+
     def to_str(self, param_name):
         if param_name in self.weird_strings:
             string = ""
diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/diversity_utils.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/diversity_utils.py
index 05e4cd05..a4984519 100644
--- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/diversity_utils.py
+++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/diversity_utils.py
@@ -30,6 +30,7 @@ class DiversityFusionClassifier(BaseMultiviewClassifier,
         train_indices, view_indices = get_examples_views_indices(X,
                                                                  train_indices,
                                                                  view_indices)
+        self.used_views = view_indices
         # TODO : Finer analysis, may support a bit of mutliclass
         if np.unique(y[train_indices]).shape[0] > 2:
             raise ValueError(
@@ -56,6 +57,7 @@ class DiversityFusionClassifier(BaseMultiviewClassifier,
         example_indices, view_indices = get_examples_views_indices(X,
                                                                    example_indices,
                                                                    view_indices)
+        self._check_views(view_indices)
         nb_class = X.get_nb_class()
         if nb_class > 2:
             nb_class = 3
diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_fusion_utils.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_fusion_utils.py
index f657c6c2..e9cbac4c 100644
--- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_fusion_utils.py
+++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_fusion_utils.py
@@ -27,6 +27,7 @@ class BaseJumboFusion(LateFusionClassifier):
         example_indices, view_indices = get_examples_views_indices(X,
                                                                    example_indices,
                                                                    view_indices)
+        self._check_views(view_indices)
         monoview_decisions = self.predict_monoview(X,
                                                    example_indices=example_indices,
                                                    view_indices=view_indices)
@@ -36,6 +37,7 @@ class BaseJumboFusion(LateFusionClassifier):
         train_indices, view_indices = get_examples_views_indices(X,
                                                                  train_indices,
                                                                  view_indices)
+        self.used_views = view_indices
         self.init_classifiers(len(view_indices),
                               nb_monoview_per_view=self.nb_monoview_per_view)
         self.fit_monoview_estimators(X, y, train_indices=train_indices,
diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/late_fusion_utils.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/late_fusion_utils.py
index e2e8da5d..d6ff8b4c 100644
--- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/late_fusion_utils.py
+++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/late_fusion_utils.py
@@ -97,6 +97,7 @@ class LateFusionClassifier(BaseMultiviewClassifier, BaseFusionClassifier):
         train_indices, view_indices = get_examples_views_indices(X,
                                                                  train_indices,
                                                                  view_indices)
+        self.used_views = view_indices
         if np.unique(y).shape[0] > 2:
             multiclass = True
         else:
diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/utils.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/utils.py
index 6aa2a7be..5fbd4d56 100644
--- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/utils.py
+++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/utils.py
@@ -6,59 +6,59 @@ def get_names(classed_list):
     return np.array([object_.__class__.__name__ for object_ in classed_list])
 
 
-class BaseMultiviewClassifier(BaseEstimator, ClassifierMixin):
+# class BaseMultiviewClassifier(BaseEstimator, ClassifierMixin):
+#
+#     def __init__(self, random_state):
+#         self.random_state = random_state
+#
+#     def genBestParams(self, detector):
+#         return dict((param_name, detector.best_params_[param_name])
+#                     for param_name in self.param_names)
+#
+#     def genParamsFromDetector(self, detector):
+#         if self.classed_params:
+#             classed_dict = dict((classed_param, get_names(
+#                 detector.cv_results_["param_" + classed_param]))
+#                                 for classed_param in self.classed_params)
+#         if self.param_names:
+#             return [(param_name,
+#                      np.array(detector.cv_results_["param_" + param_name]))
+#                     if param_name not in self.classed_params else (
+#                 param_name, classed_dict[param_name])
+#                     for param_name in self.param_names]
+#         else:
+#             return [()]
+#
+#     def genDistribs(self):
+#         return dict((param_name, distrib) for param_name, distrib in
+#                     zip(self.param_names, self.distribs))
+#
+#     def getConfig(self):
+#         if self.param_names:
+#             return "\n\t\t- " + self.__class__.__name__ + "with " + ", ".join(
+#                 [param_name + " : " + self.to_str(param_name) for param_name in
+#                  self.param_names])
+#         else:
+#             return "\n\t\t- " + self.__class__.__name__ + "with no config."
+#
+#     def to_str(self, param_name):
+#         if param_name in self.weird_strings:
+#             if self.weird_strings[param_name] == "class_name":
+#                 return self.get_params()[param_name].__class__.__name__
+#             else:
+#                 return self.weird_strings[param_name](
+#                     self.get_params()[param_name])
+#         else:
+#             return str(self.get_params()[param_name])
+#
+#     def get_interpretation(self):
+#         return "No detailed interpretation function"
 
-    def __init__(self, random_state):
-        self.random_state = random_state
-
-    def genBestParams(self, detector):
-        return dict((param_name, detector.best_params_[param_name])
-                    for param_name in self.param_names)
-
-    def genParamsFromDetector(self, detector):
-        if self.classed_params:
-            classed_dict = dict((classed_param, get_names(
-                detector.cv_results_["param_" + classed_param]))
-                                for classed_param in self.classed_params)
-        if self.param_names:
-            return [(param_name,
-                     np.array(detector.cv_results_["param_" + param_name]))
-                    if param_name not in self.classed_params else (
-                param_name, classed_dict[param_name])
-                    for param_name in self.param_names]
-        else:
-            return [()]
-
-    def genDistribs(self):
-        return dict((param_name, distrib) for param_name, distrib in
-                    zip(self.param_names, self.distribs))
-
-    def getConfig(self):
-        if self.param_names:
-            return "\n\t\t- " + self.__class__.__name__ + "with " + ", ".join(
-                [param_name + " : " + self.to_str(param_name) for param_name in
-                 self.param_names])
-        else:
-            return "\n\t\t- " + self.__class__.__name__ + "with no config."
-
-    def to_str(self, param_name):
-        if param_name in self.weird_strings:
-            if self.weird_strings[param_name] == "class_name":
-                return self.get_params()[param_name].__class__.__name__
-            else:
-                return self.weird_strings[param_name](
-                    self.get_params()[param_name])
-        else:
-            return str(self.get_params()[param_name])
-
-    def get_interpretation(self):
-        return "No detailed interpretation function"
-
-
-def get_train_views_indices(dataset, train_indices, view_indices, ):
-    """This function  is used to get all the examples indices and view indices if needed"""
-    if view_indices is None:
-        view_indices = np.arange(dataset.nb_view)
-    if train_indices is None:
-        train_indices = range(dataset.get_nb_examples())
-    return train_indices, view_indices
+#
+# def get_train_views_indices(dataset, train_indices, view_indices, ):
+#     """This function  is used to get all the examples indices and view indices if needed"""
+#     if view_indices is None:
+#         view_indices = np.arange(dataset.nb_view)
+#     if train_indices is None:
+#         train_indices = range(dataset.get_nb_examples())
+#     return train_indices, view_indices
diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/bayesian_inference_fusion.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/bayesian_inference_fusion.py
index 5c5ae1c2..b1cd5f9e 100644
--- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/bayesian_inference_fusion.py
+++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/bayesian_inference_fusion.py
@@ -23,7 +23,7 @@ class BayesianInferenceClassifier(LateFusionClassifier):
         example_indices, view_indices = get_examples_views_indices(X,
                                                                    example_indices,
                                                                    view_indices)
-
+        self._check_views(view_indices)
         if sum(self.weights) != 1.0:
             self.weights = self.weights / sum(self.weights)
 
diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/majority_voting_fusion.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/majority_voting_fusion.py
index 23c102b6..53a255c7 100644
--- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/majority_voting_fusion.py
+++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/majority_voting_fusion.py
@@ -23,16 +23,16 @@ class MajorityVoting(LateFusionClassifier):
                                       rs=rs)
 
     def predict(self, X, example_indices=None, view_indices=None):
-        examples_indices, views_indices = get_examples_views_indices(X,
+        examples_indices, view_indices = get_examples_views_indices(X,
                                                                      example_indices,
                                                                      view_indices)
-
+        self._check_views(view_indices)
         n_examples = len(examples_indices)
         votes = np.zeros((n_examples, X.get_nb_class(example_indices)),
                          dtype=float)
         monoview_decisions = np.zeros((len(examples_indices), X.nb_view),
                                       dtype=int)
-        for index, view_index in enumerate(views_indices):
+        for index, view_index in enumerate(view_indices):
             monoview_decisions[:, index] = self.monoview_estimators[
                 index].predict(
                 X.get_v(view_index, examples_indices))
diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py
index 83b4c555..eaa8ce34 100644
--- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py
+++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py
@@ -69,6 +69,7 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier):
     def fit(self, X, y, train_indices=None, view_indices=None):
         train_indices, X = self.transform_data_to_monoview(X, train_indices,
                                                            view_indices)
+        self.used_views = view_indices
         if np.unique(y[train_indices]).shape[0] > 2 and \
                 not (isinstance(self.monoview_classifier, MultiClassWrapper)):
             self.monoview_classifier = get_mc_estim(self.monoview_classifier,
@@ -81,6 +82,7 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier):
 
     def predict(self, X, example_indices=None, view_indices=None):
         _, X = self.transform_data_to_monoview(X, example_indices, view_indices)
+        self._check_views(self.view_indices)
         predicted_labels = self.monoview_classifier.predict(X)
         return predicted_labels
 
diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_late_fusion.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_late_fusion.py
index 32f4a710..403791ce 100644
--- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_late_fusion.py
+++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_late_fusion.py
@@ -17,11 +17,12 @@ class WeightedLinearLateFusion(LateFusionClassifier):
                                       nb_cores=nb_cores, weights=weights, rs=rs)
 
     def predict(self, X, example_indices=None, view_indices=None):
-        example_indices, views_indices = get_examples_views_indices(X,
+        example_indices, view_indices = get_examples_views_indices(X,
                                                                     example_indices,
                                                                     view_indices)
+        self._check_views(view_indices)
         view_scores = []
-        for index, viewIndex in enumerate(views_indices):
+        for index, viewIndex in enumerate(view_indices):
             view_scores.append(
                 np.array(self.monoview_estimators[index].predict_proba(
                     X.get_v(viewIndex, example_indices))) * self.weights[index])
-- 
GitLab