From 05f4c04887e1827a84983a6fad7f364cfd3aa6a1 Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Sat, 1 Dec 2018 12:36:56 -0500
Subject: [PATCH] Start worjing

---
 .../Monoview/Additions/BoostUtils.py                     | 2 +-
 .../Monoview/Additions/CQBoostUtils.py                   | 2 +-
 .../Monoview/Additions/QarBoostUtils.py                  | 9 ++++-----
 .../MonoviewClassifiers/CQBoostv21.py                    | 2 +-
 .../MonoMultiViewClassifiers/ResultAnalysis.py           | 5 ++++-
 5 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/multiview_platform/MonoMultiViewClassifiers/Monoview/Additions/BoostUtils.py b/multiview_platform/MonoMultiViewClassifiers/Monoview/Additions/BoostUtils.py
index 3f1e4c4e..f5191941 100644
--- a/multiview_platform/MonoMultiViewClassifiers/Monoview/Additions/BoostUtils.py
+++ b/multiview_platform/MonoMultiViewClassifiers/Monoview/Additions/BoostUtils.py
@@ -696,7 +696,7 @@ def get_accuracy_graph(train_accuracies, classifier_name, file_name, name="Accur
 class BaseBoost(object):
 
     def __init__(self):
-        self.n_stumps_per_attribute = 1
+        self.n_stumps = 1
 
     def _collect_probas(self, X):
         return np.asarray([clf.predict_proba(X) for clf in self.estimators_generator.estimators_])
diff --git a/multiview_platform/MonoMultiViewClassifiers/Monoview/Additions/CQBoostUtils.py b/multiview_platform/MonoMultiViewClassifiers/Monoview/Additions/CQBoostUtils.py
index f5ba4681..c7ebf500 100644
--- a/multiview_platform/MonoMultiViewClassifiers/Monoview/Additions/CQBoostUtils.py
+++ b/multiview_platform/MonoMultiViewClassifiers/Monoview/Additions/CQBoostUtils.py
@@ -32,7 +32,7 @@ class ColumnGenerationClassifier(BaseEstimator, ClassifierMixin, BaseBoost):
         y[y == 0] = -1
 
         if self.estimators_generator is None:
-            self.estimators_generator = StumpsClassifiersGenerator(n_stumps_per_attribute=self.n_stumps_per_attribute, self_complemented=True)
+            self.estimators_generator = StumpsClassifiersGenerator(n_stumps_per_attribute=self.n_stumps, self_complemented=True)
 
         self.estimators_generator.fit(X, y)
         self.classification_matrix = self._binary_classification_matrix(X)
diff --git a/multiview_platform/MonoMultiViewClassifiers/Monoview/Additions/QarBoostUtils.py b/multiview_platform/MonoMultiViewClassifiers/Monoview/Additions/QarBoostUtils.py
index f684a518..f87bfb34 100644
--- a/multiview_platform/MonoMultiViewClassifiers/Monoview/Additions/QarBoostUtils.py
+++ b/multiview_platform/MonoMultiViewClassifiers/Monoview/Additions/QarBoostUtils.py
@@ -36,7 +36,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
         self.divided_ponderation = divided_ponderation
         self.plotted_metric = plotted_metric
         if n_stumps_per_attribute:
-            self.n_stumps_per_attribute = n_stumps_per_attribute
+            self.n_stumps = n_stumps_per_attribute
         self.use_r = use_r
         self.printed_args_name_list = ["n_max_iterations", "self_complemented", "twice_the_same", "old_fashioned",
                                        "previous_vote_weighted", "c_bound_choice", "random_start",
@@ -58,7 +58,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
             X = np.array(X.todense())
 
         if self.estimators_generator is None:
-            self.estimators_generator = StumpsClassifiersGenerator(n_stumps_per_attribute=self.n_stumps_per_attribute,
+            self.estimators_generator = StumpsClassifiersGenerator(n_stumps_per_attribute=self.n_stumps,
                                                                    self_complemented=self.self_complemented)
         # Initialization
         y[y == 0] = -1
@@ -237,14 +237,13 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
         worst_h_index = ma.argmax(pseudo_h_values)
         return worst_h_index
 
-    def _find_best_weighted_margin(self, y_kernel_matrix):
+    def _find_best_weighted_margin(self, y_kernel_matrix, upper_bound=1.0, lower_bound=0.0):
         """Finds the new voter by choosing the one that has the best weighted margin between 0.5 and 0.55
         to avoid too god voters that will get all the votes weights"""
-        upper_bound = 0.55
         weighted_kernel_matrix = np.multiply(y_kernel_matrix, self.example_weights.reshape((self.n_total_examples, 1)))
         pseudo_h_values = ma.array(np.sum(weighted_kernel_matrix, axis=0), fill_value=-np.inf)
         pseudo_h_values[self.chosen_columns_] = ma.masked
-        acceptable_indices = np.where(np.logical_and(np.greater(upper_bound, pseudo_h_values), np.greater(pseudo_h_values, 0.5)))[0]
+        acceptable_indices = np.where(np.logical_and(np.greater(upper_bound, pseudo_h_values), np.greater(pseudo_h_values, lower_bound)))[0]
         if acceptable_indices.size > 0:
             worst_h_index = self.random_state.choice(acceptable_indices)
             return worst_h_index, [0]
diff --git a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/CQBoostv21.py b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/CQBoostv21.py
index f5b0ba94..9274d9dc 100644
--- a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/CQBoostv21.py
+++ b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/CQBoostv21.py
@@ -28,7 +28,7 @@ class ColumnGenerationClassifierv21(BaseEstimator, ClassifierMixin, BaseBoost):
             X = np.array(X.todense())
 
         if self.estimators_generator is None:
-            self.estimators_generator = StumpsClassifiersGenerator(n_stumps_per_attribute=self.n_stumps_per_attribute, self_complemented=True)
+            self.estimators_generator = StumpsClassifiersGenerator(n_stumps_per_attribute=self.n_stumps, self_complemented=True)
 
         y[y == 0] = -1
 
diff --git a/multiview_platform/MonoMultiViewClassifiers/ResultAnalysis.py b/multiview_platform/MonoMultiViewClassifiers/ResultAnalysis.py
index a15aa90e..c12d0b60 100644
--- a/multiview_platform/MonoMultiViewClassifiers/ResultAnalysis.py
+++ b/multiview_platform/MonoMultiViewClassifiers/ResultAnalysis.py
@@ -237,7 +237,10 @@ def plotMetricScores(trainScores, testScores, names, nbResults, metricName, file
     ax.set_xticks(np.arange(nbResults) + barWidth)
     ax.set_xticklabels(names, rotation="vertical")
 
-    plt.tight_layout()
+    try:
+        plt.tight_layout()
+    except:
+        pass
     f.savefig(fileName)
     plt.close()
 
-- 
GitLab