diff --git a/code/bolsonaro/models/omp_forest_classifier.py b/code/bolsonaro/models/omp_forest_classifier.py index 26d9f6af3cb4472ac835e7b6ebbfb2b4f146ffe3..a86e53ba862a99379e92a9dc9ca6a688178164c5 100644 --- a/code/bolsonaro/models/omp_forest_classifier.py +++ b/code/bolsonaro/models/omp_forest_classifier.py @@ -33,7 +33,8 @@ class OmpForestBinaryClassifier(SingleOmpForest): :param X: a Forest :return: a np.array of the predictions of the entire forest """ - forest_predictions = self._base_estimator_predictions(X).T + + forest_predictions = np.array([tree.predict_proba(X) for tree in self._base_forest_estimator.estimators_]) if self._models_parameters.normalize_D: forest_predictions /= self._forest_norms @@ -41,9 +42,17 @@ class OmpForestBinaryClassifier(SingleOmpForest): weights = self._omp.coef_ omp_trees_indices = np.nonzero(weights) - select_trees = np.argmax(forest_predictions[omp_trees_indices], axis=0) + omp_trees_predictions = forest_predictions[omp_trees_indices].T[1] + + # Here forest_pred is the probability of being class 1. + + result_omp = np.mean(omp_trees_predictions, axis=1) + + result_omp = (result_omp - 0.5) * 2 + + print(result_omp) - return select_trees + return result_omp def score(self, X, y, metric=DEFAULT_SCORE_METRIC): """