From 34bca5fedd2f9558dc5d2052001690d8405fd6ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20Bouscarrat?= <leo.bouscarrat@euranova.eu> Date: Thu, 5 Mar 2020 12:00:28 +0100 Subject: [PATCH] Argmax instead of mean for predict in binary --- .../bolsonaro/models/omp_forest_classifier.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/code/bolsonaro/models/omp_forest_classifier.py b/code/bolsonaro/models/omp_forest_classifier.py index ccaf3eb..26d9f6a 100644 --- a/code/bolsonaro/models/omp_forest_classifier.py +++ b/code/bolsonaro/models/omp_forest_classifier.py @@ -24,6 +24,27 @@ class OmpForestBinaryClassifier(SingleOmpForest): return super().fit(X_forest, y_forest, X_omp, y_omp) + def predict_no_weights(self, X): + """ + Apply the SingleOmpForest to X without using the weights. + + Make all the base tree predictions + + :param X: a Forest + :return: a np.array of the predictions of the entire forest + """ + forest_predictions = self._base_estimator_predictions(X).T + + if self._models_parameters.normalize_D: + forest_predictions /= self._forest_norms + + weights = self._omp.coef_ + omp_trees_indices = np.nonzero(weights) + + select_trees = np.argmax(forest_predictions[omp_trees_indices], axis=0) + + return select_trees + def score(self, X, y, metric=DEFAULT_SCORE_METRIC): """ Evaluate OMPForestClassifer on (`X`, `y`) using `metric` -- GitLab