Skip to content
Snippets Groups Projects
Commit fafad9a3 authored by Luc Giffon's avatar Luc Giffon
Browse files

fix bug predict base estimator wo weights classif

parent bbe3c3d7
No related branches found
No related tags found
1 merge request!24Resolve "non negative omp"
......@@ -36,11 +36,11 @@ class NonNegativeOmpForestBinaryClassifier(OmpForestBinaryClassifier):
:param X: some data to apply the forest to
:return: a np.array of the predictions of the trees selected by OMP without applying the weight
"""
forest_predictions = np.array([tree.predict(X) for tree in self._base_forest_estimator.estimators_])
forest_predictions = self._base_estimator_predictions(X)
if forest_size is not None:
weights = self._omp.get_coef(forest_size)
select_trees = np.mean(forest_predictions[weights != 0], axis=0)
select_trees = np.mean(forest_predictions[:, weights != 0], axis=1)
return select_trees
else:
lst_predictions = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment