diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py index 1d06dbe7291a41b8dc50f1ca7d8b719f0c197512..df6f024b0dd172ae62acdd866ce6248fb2117756 100644 --- a/code/bolsonaro/trainer.py +++ b/code/bolsonaro/trainer.py @@ -154,10 +154,12 @@ class Trainer(object): return result def _evaluate_predictions(self, model, X, aggregation_function): - if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor, - OmpForestBinaryClassifier, OmpForestMulticlassClassifier, SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]: + if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier, OmpForestMulticlassClassifier]: estimators = model.forest estimators = np.asarray(estimators)[model._omp.coef_ != 0] + elif type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor, + SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]: + estimators = model.selected_trees elif type(model) in [RandomForestRegressor, RandomForestClassifier]: estimators = model.estimators_