From baca128104b726057e5f878817a7e99d611f6b9a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C3=A9o=20Bouscarrat?= <leo.bouscarrat@euranova.eu>
Date: Thu, 5 Mar 2020 14:44:45 +0100
Subject: [PATCH] Correction for predict_no_weights

---
 code/bolsonaro/models/omp_forest_classifier.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/code/bolsonaro/models/omp_forest_classifier.py b/code/bolsonaro/models/omp_forest_classifier.py
index 26d9f6a..a86e53b 100644
--- a/code/bolsonaro/models/omp_forest_classifier.py
+++ b/code/bolsonaro/models/omp_forest_classifier.py
@@ -33,7 +33,8 @@ class OmpForestBinaryClassifier(SingleOmpForest):
         :param X: a Forest
         :return: a np.array of the predictions of the entire forest
         """
-        forest_predictions = self._base_estimator_predictions(X).T
+
+        forest_predictions = np.array([tree.predict_proba(X) for tree in self._base_forest_estimator.estimators_])
 
         if self._models_parameters.normalize_D:
             forest_predictions /= self._forest_norms
@@ -41,9 +42,17 @@ class OmpForestBinaryClassifier(SingleOmpForest):
         weights = self._omp.coef_
         omp_trees_indices = np.nonzero(weights)
 
-        select_trees = np.argmax(forest_predictions[omp_trees_indices], axis=0)
+        omp_trees_predictions = forest_predictions[omp_trees_indices].T[1]
+
+        # Here forest_pred is the probability of being class 1.
+
+        result_omp = np.mean(omp_trees_predictions, axis=1)
+
+        result_omp = (result_omp - 0.5) * 2
+
+        print(result_omp)
 
-        return select_trees
+        return result_omp
 
     def score(self, X, y, metric=DEFAULT_SCORE_METRIC):
         """
-- 
GitLab