diff --git a/code/bolsonaro/models/nn_omp_forest_classifier.py b/code/bolsonaro/models/nn_omp_forest_classifier.py index 1279b7ad2874b7cdd4503859c7d332541d447a91..f33cb0691f448c4522d2da1ab29499057436c04a 100644 --- a/code/bolsonaro/models/nn_omp_forest_classifier.py +++ b/code/bolsonaro/models/nn_omp_forest_classifier.py @@ -36,11 +36,11 @@ class NonNegativeOmpForestBinaryClassifier(OmpForestBinaryClassifier): :param X: some data to apply the forest to :return: a np.array of the predictions of the trees selected by OMP without applying the weight """ - forest_predictions = np.array([tree.predict(X) for tree in self._base_forest_estimator.estimators_]) + forest_predictions = self._base_estimator_predictions(X) if forest_size is not None: weights = self._omp.get_coef(forest_size) - select_trees = np.mean(forest_predictions[weights != 0], axis=0) + select_trees = np.mean(forest_predictions[:, weights != 0], axis=1) return select_trees else: lst_predictions = []