From 8b7fe3664cddb8778b9b1a35b718730fd98b0466 Mon Sep 17 00:00:00 2001
From: Charly Lamothe <charly.lamothe@univ-amu.fr>
Date: Wed, 25 Mar 2020 00:36:53 +0100
Subject: [PATCH] Fix _evaluate_predictions for SOTA classifiers

---
 code/bolsonaro/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py
index 21b7ab1..1d06dbe 100644
--- a/code/bolsonaro/trainer.py
+++ b/code/bolsonaro/trainer.py
@@ -155,7 +155,7 @@ class Trainer(object):
 
     def _evaluate_predictions(self, model, X, aggregation_function):
         if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor,
-            OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
+            OmpForestBinaryClassifier, OmpForestMulticlassClassifier, SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]:
             estimators = model.forest
             estimators = np.asarray(estimators)[model._omp.coef_ != 0]
         elif type(model) in [RandomForestRegressor, RandomForestClassifier]:
-- 
GitLab