From c616064687f8c82df631be1462b02c8992fd9855 Mon Sep 17 00:00:00 2001 From: Charly Lamothe <charly.lamothe@univ-amu.fr> Date: Wed, 25 Mar 2020 00:46:08 +0100 Subject: [PATCH] Fix _evaluate_predictions estimators setting for SOTA methods using new selected_trees attribute --- code/bolsonaro/trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py index 1d06dbe..df6f024 100644 --- a/code/bolsonaro/trainer.py +++ b/code/bolsonaro/trainer.py @@ -154,10 +154,12 @@ class Trainer(object): return result def _evaluate_predictions(self, model, X, aggregation_function): - if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor, - OmpForestBinaryClassifier, OmpForestMulticlassClassifier, SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]: + if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier, OmpForestMulticlassClassifier]: estimators = model.forest estimators = np.asarray(estimators)[model._omp.coef_ != 0] + elif type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor, + SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]: + estimators = model.selected_trees elif type(model) in [RandomForestRegressor, RandomForestClassifier]: estimators = model.estimators_ -- GitLab