From 347b7f5db8fc7d965ab0c668a3804f7bcc9ef2f3 Mon Sep 17 00:00:00 2001
From: Charly Lamothe <charly.lamothe@univ-amu.fr>
Date: Fri, 13 Mar 2020 16:27:03 +0100
Subject: [PATCH] Fix omp predictions in coherence

---
 code/bolsonaro/trainer.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py
index c2bd767..9d4911e 100644
--- a/code/bolsonaro/trainer.py
+++ b/code/bolsonaro/trainer.py
@@ -145,16 +145,14 @@ class Trainer(object):
         from sklearn.preprocessing import normalize
         import itertools
 
-        if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor]:
-            estimators = model.forest
-        elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
+        if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor,
+            OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
             estimators = model.forest
+            estimators = np.asarray(estimators)[model._omp.coef_ != 0]
         elif type(model) in [RandomForestRegressor, RandomForestClassifier]:
             estimators = model.estimators_
 
-        predictions = list()
-        for ti in estimators:
-            predictions.append(ti.predict(X))
+        predictions = np.array([tree.predict(X) for tree in estimators])
 
         predictions = normalize(predictions)
 
-- 
GitLab