From fafad9a397585a5498163a6cf43a469e8a8bd6bb Mon Sep 17 00:00:00 2001 From: Luc Giffon <luc.giffon@lis-lab.fr> Date: Sun, 29 Mar 2020 20:13:05 +0200 Subject: [PATCH] fix bug predict base estimator wo weights classif --- code/bolsonaro/models/nn_omp_forest_classifier.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/code/bolsonaro/models/nn_omp_forest_classifier.py b/code/bolsonaro/models/nn_omp_forest_classifier.py index 1279b7a..f33cb06 100644 --- a/code/bolsonaro/models/nn_omp_forest_classifier.py +++ b/code/bolsonaro/models/nn_omp_forest_classifier.py @@ -36,11 +36,11 @@ class NonNegativeOmpForestBinaryClassifier(OmpForestBinaryClassifier): :param X: some data to apply the forest to :return: a np.array of the predictions of the trees selected by OMP without applying the weight """ - forest_predictions = np.array([tree.predict(X) for tree in self._base_forest_estimator.estimators_]) + forest_predictions = self._base_estimator_predictions(X) if forest_size is not None: weights = self._omp.get_coef(forest_size) - select_trees = np.mean(forest_predictions[weights != 0], axis=0) + select_trees = np.mean(forest_predictions[:, weights != 0], axis=1) return select_trees else: lst_predictions = [] -- GitLab