Skip to content
Snippets Groups Projects
Commit baca1281 authored by Léo Bouscarrat's avatar Léo Bouscarrat
Browse files

Correction for predict_no_weights

parent 264288b4
No related branches found
No related tags found
1 merge request!15Resolve "Adding new datasets"
This commit is part of merge request !15. Comments created here will be created in the context of that merge request.
......@@ -33,7 +33,8 @@ class OmpForestBinaryClassifier(SingleOmpForest):
:param X: a Forest
:return: a np.array of the predictions of the entire forest
"""
forest_predictions = self._base_estimator_predictions(X).T
forest_predictions = np.array([tree.predict_proba(X) for tree in self._base_forest_estimator.estimators_])
if self._models_parameters.normalize_D:
forest_predictions /= self._forest_norms
......@@ -41,9 +42,17 @@ class OmpForestBinaryClassifier(SingleOmpForest):
weights = self._omp.coef_
omp_trees_indices = np.nonzero(weights)
select_trees = np.argmax(forest_predictions[omp_trees_indices], axis=0)
omp_trees_predictions = forest_predictions[omp_trees_indices].T[1]
# Here forest_pred is the probability of being class 1.
result_omp = np.mean(omp_trees_predictions, axis=1)
result_omp = (result_omp - 0.5) * 2
print(result_omp)
return select_trees
return result_omp
def score(self, X, y, metric=DEFAULT_SCORE_METRIC):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment