From 347b7f5db8fc7d965ab0c668a3804f7bcc9ef2f3 Mon Sep 17 00:00:00 2001 From: Charly Lamothe <charly.lamothe@univ-amu.fr> Date: Fri, 13 Mar 2020 16:27:03 +0100 Subject: [PATCH] Fix omp predictions in coherence --- code/bolsonaro/trainer.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py index c2bd767..9d4911e 100644 --- a/code/bolsonaro/trainer.py +++ b/code/bolsonaro/trainer.py @@ -145,16 +145,14 @@ class Trainer(object): from sklearn.preprocessing import normalize import itertools - if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor]: - estimators = model.forest - elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]: + if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor, + OmpForestBinaryClassifier, OmpForestMulticlassClassifier]: estimators = model.forest + estimators = np.asarray(estimators)[model._omp.coef_ != 0] elif type(model) in [RandomForestRegressor, RandomForestClassifier]: estimators = model.estimators_ - predictions = list() - for ti in estimators: - predictions.append(ti.predict(X)) + predictions = np.array([tree.predict(X) for tree in estimators]) predictions = normalize(predictions) -- GitLab