diff --git a/code/bolsonaro/models/omp_forest_classifier.py b/code/bolsonaro/models/omp_forest_classifier.py index ccaf3ebc2b630798f62ab17a1285ab28b366ed95..26d9f6af3cb4472ac835e7b6ebbfb2b4f146ffe3 100644 --- a/code/bolsonaro/models/omp_forest_classifier.py +++ b/code/bolsonaro/models/omp_forest_classifier.py @@ -24,6 +24,27 @@ class OmpForestBinaryClassifier(SingleOmpForest): return super().fit(X_forest, y_forest, X_omp, y_omp) + def predict_no_weights(self, X): + """ + Apply the SingleOmpForest to X without using the weights. + + Make all the base tree predictions + + :param X: a Forest + :return: a np.array of the predictions of the entire forest + """ + forest_predictions = self._base_estimator_predictions(X).T + + if self._models_parameters.normalize_D: + forest_predictions /= self._forest_norms + + weights = self._omp.coef_ + omp_trees_indices = np.nonzero(weights) + + select_trees = np.argmax(forest_predictions[omp_trees_indices], axis=0) + + return select_trees + def score(self, X, y, metric=DEFAULT_SCORE_METRIC): """ Evaluate OMPForestClassifer on (`X`, `y`) using `metric`