From c678317301e37cb8c926120b06702605b1257fdf Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Wed, 2 Oct 2019 09:01:33 -0400
Subject: [PATCH] Corrected difficulty score

---
 .../additions/diversity_utils.py                 |  2 +-
 .../multiview_classifiers/difficulty_fusion.py   | 16 +++++++---------
 .../test_difficulty_fusion.py                    |  2 +-
 3 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/diversity_utils.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/diversity_utils.py
index a72d82aa..98956a5a 100644
--- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/diversity_utils.py
+++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/diversity_utils.py
@@ -49,7 +49,7 @@ class DiversityFusion(BaseMultiviewClassifier):
                     estimator.fit(X.get_v(view_idx, train_indices), y[train_indices])
                     self.estimator_pool[classifier_idx].append(estimator)
         else:
-            pass #Todo
+            pass #TODO
         self.monoview_estimators = self.choose_combination(X, y, train_indices, views_indices)
         return self
 
diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/difficulty_fusion.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/difficulty_fusion.py
index 4e454377..dbdc97de 100644
--- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/difficulty_fusion.py
+++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/difficulty_fusion.py
@@ -9,24 +9,22 @@ classifier_class_name = "DifficultyFusion"
 class DifficultyFusion(GlobalDiversityFusion):
 
     def diversity_measure(self, classifiers_decisions, combination, y):
-
         _, nb_view, nb_examples = classifiers_decisions.shape
         scores = np.zeros((nb_view, nb_examples), dtype=int)
         for view_index, classifier_index in enumerate(combination):
-            scores[view_index] = np.logical_not(
+            scores[view_index, :] = np.logical_not(
                     np.logical_xor(classifiers_decisions[classifier_index,
                                                          view_index],
                                    y)
                 )
-        # difficulty_scores = np.sum(scores, axis=0)
-        # TODO : Check computing method
-        difficulty_score = np.mean(
-            np.var(
+        # Table of the nuber of views that succeeded for each example :
+        difficulty_scores = np.sum(scores, axis=0)
+
+        difficulty_score = np.var(
                 np.array([
-                             np.sum((scores==view_index), axis=1)/float(nb_view)
+                             np.sum((difficulty_scores == view_index))
                              for view_index in range(len(combination)+1)])
-                , axis=0)
-        )
+                )
         return difficulty_score
 
 
diff --git a/multiview_platform/tests/test_multiview_classifiers/test_difficulty_fusion.py b/multiview_platform/tests/test_multiview_classifiers/test_difficulty_fusion.py
index 4efc6054..b49eb1f6 100644
--- a/multiview_platform/tests/test_multiview_classifiers/test_difficulty_fusion.py
+++ b/multiview_platform/tests/test_multiview_classifiers/test_difficulty_fusion.py
@@ -20,4 +20,4 @@ class Test_difficulty_fusion(unittest.TestCase):
             cls.classifiers_decisions,
             cls.combination,
             cls.y)
-        cls.assertAlmostEqual(difficulty_measure, 0.22453703703703706)
+        cls.assertAlmostEqual(difficulty_measure, 0.1875)
-- 
GitLab