Commit 34bca5fe authored by Léo Bouscarrat's avatar Léo Bouscarrat
Browse files

Argmax instead of mean for predict in binary

parent 72465f53
......@@ -24,6 +24,27 @@ class OmpForestBinaryClassifier(SingleOmpForest):
return super().fit(X_forest, y_forest, X_omp, y_omp)
def predict_no_weights(self, X):
"""
Apply the SingleOmpForest to X without using the weights.
Make all the base tree predictions
:param X: a Forest
:return: a np.array of the predictions of the entire forest
"""
forest_predictions = self._base_estimator_predictions(X).T
if self._models_parameters.normalize_D:
forest_predictions /= self._forest_norms
weights = self._omp.coef_
omp_trees_indices = np.nonzero(weights)
select_trees = np.argmax(forest_predictions[omp_trees_indices], axis=0)
return select_trees
def score(self, X, y, metric=DEFAULT_SCORE_METRIC):
"""
Evaluate OMPForestClassifer on (`X`, `y`) using `metric`
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment