From 34bca5fedd2f9558dc5d2052001690d8405fd6ed Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C3=A9o=20Bouscarrat?= <leo.bouscarrat@euranova.eu>
Date: Thu, 5 Mar 2020 12:00:28 +0100
Subject: [PATCH] Argmax instead of mean for predict in binary

---
 .../bolsonaro/models/omp_forest_classifier.py | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/code/bolsonaro/models/omp_forest_classifier.py b/code/bolsonaro/models/omp_forest_classifier.py
index ccaf3eb..26d9f6a 100644
--- a/code/bolsonaro/models/omp_forest_classifier.py
+++ b/code/bolsonaro/models/omp_forest_classifier.py
@@ -24,6 +24,27 @@ class OmpForestBinaryClassifier(SingleOmpForest):
 
         return super().fit(X_forest, y_forest, X_omp, y_omp)
 
+    def predict_no_weights(self, X):
+        """
+        Apply the SingleOmpForest to X without using the weights.
+
+        Make all the base tree predictions
+
+        :param X: a Forest
+        :return: a np.array of the predictions of the entire forest
+        """
+        forest_predictions = self._base_estimator_predictions(X).T
+
+        if self._models_parameters.normalize_D:
+            forest_predictions /= self._forest_norms
+
+        weights = self._omp.coef_
+        omp_trees_indices = np.nonzero(weights)
+
+        select_trees = np.argmax(forest_predictions[omp_trees_indices], axis=0)
+
+        return select_trees
+
     def score(self, X, y, metric=DEFAULT_SCORE_METRIC):
         """
         Evaluate OMPForestClassifer on (`X`, `y`) using `metric`
-- 
GitLab