From c616064687f8c82df631be1462b02c8992fd9855 Mon Sep 17 00:00:00 2001
From: Charly Lamothe <charly.lamothe@univ-amu.fr>
Date: Wed, 25 Mar 2020 00:46:08 +0100
Subject: [PATCH] Fix _evaluate_predictions estimators setting for SOTA methods
 using new selected_trees attribute

---
 code/bolsonaro/trainer.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py
index 1d06dbe..df6f024 100644
--- a/code/bolsonaro/trainer.py
+++ b/code/bolsonaro/trainer.py
@@ -154,10 +154,12 @@ class Trainer(object):
         return result
 
     def _evaluate_predictions(self, model, X, aggregation_function):
-        if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor,
-            OmpForestBinaryClassifier, OmpForestMulticlassClassifier, SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]:
+        if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
             estimators = model.forest
             estimators = np.asarray(estimators)[model._omp.coef_ != 0]
+        elif type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor, 
+            SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]:
+            estimators = model.selected_trees
         elif type(model) in [RandomForestRegressor, RandomForestClassifier]:
             estimators = model.estimators_
 
-- 
GitLab