From af068a00d443ee1ecd5f7e00f8d93e404aaf1bc3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C3=A9o=20Bouscarrat?= <leo.bouscarrat@euranova.eu>
Date: Fri, 6 Mar 2020 10:49:58 +0100
Subject: [PATCH] Better to take non zero values of list as indicated in the
 numpy doc

---
 code/bolsonaro/models/omp_forest.py            | 7 ++-----
 code/bolsonaro/models/omp_forest_classifier.py | 4 +---
 2 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/code/bolsonaro/models/omp_forest.py b/code/bolsonaro/models/omp_forest.py
index 5b947d3..35c2f8f 100644
--- a/code/bolsonaro/models/omp_forest.py
+++ b/code/bolsonaro/models/omp_forest.py
@@ -136,14 +136,11 @@ class SingleOmpForest(OmpForest):
         :param X: a Forest
         :return: a np.array of the predictions of the entire forest
         """
-        forest_predictions = self._base_estimator_predictions(X).T
+        forest_predictions = np.array([tree.predict(X) for tree in self._base_forest_estimator.estimators_])
 
         if self._models_parameters.normalize_D:
             forest_predictions /= self._forest_norms
 
         weights = self._omp.coef_
-        omp_trees_indices = np.nonzero(weights)[0]
-
-        select_trees = np.mean(forest_predictions[omp_trees_indices], axis=0)
-        print(len(omp_trees_indices))
+        select_trees = np.mean(forest_predictions[weights != 0], axis=0)
         return select_trees
diff --git a/code/bolsonaro/models/omp_forest_classifier.py b/code/bolsonaro/models/omp_forest_classifier.py
index a51405a..3051fad 100644
--- a/code/bolsonaro/models/omp_forest_classifier.py
+++ b/code/bolsonaro/models/omp_forest_classifier.py
@@ -40,9 +40,7 @@ class OmpForestBinaryClassifier(SingleOmpForest):
             forest_predictions /= self._forest_norms
 
         weights = self._omp.coef_
-        omp_trees_indices = np.nonzero(weights)
-
-        omp_trees_predictions = forest_predictions[omp_trees_indices].T[1]
+        omp_trees_predictions = forest_predictions[weights != 0].T[1]
 
         # Here forest_pred is the probability of being class 1.
 
-- 
GitLab