From fafad9a397585a5498163a6cf43a469e8a8bd6bb Mon Sep 17 00:00:00 2001
From: Luc Giffon <luc.giffon@lis-lab.fr>
Date: Sun, 29 Mar 2020 20:13:05 +0200
Subject: [PATCH] fix bug predict base estimator wo weights classif

---
 code/bolsonaro/models/nn_omp_forest_classifier.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/code/bolsonaro/models/nn_omp_forest_classifier.py b/code/bolsonaro/models/nn_omp_forest_classifier.py
index 1279b7a..f33cb06 100644
--- a/code/bolsonaro/models/nn_omp_forest_classifier.py
+++ b/code/bolsonaro/models/nn_omp_forest_classifier.py
@@ -36,11 +36,11 @@ class NonNegativeOmpForestBinaryClassifier(OmpForestBinaryClassifier):
         :param X: some data to apply the forest to
         :return: a np.array of the predictions of the trees selected by OMP without applying the weight
         """
-        forest_predictions = np.array([tree.predict(X) for tree in self._base_forest_estimator.estimators_])
+        forest_predictions = self._base_estimator_predictions(X)
 
         if forest_size is not None:
             weights = self._omp.get_coef(forest_size)
-            select_trees = np.mean(forest_predictions[weights != 0], axis=0)
+            select_trees = np.mean(forest_predictions[:, weights != 0], axis=1)
             return select_trees
         else:
             lst_predictions = []
-- 
GitLab