Skip to content
Snippets Groups Projects
Commit fafad9a3 authored by Luc Giffon's avatar Luc Giffon
Browse files

fix bug predict base estimator wo weights classif

parent bbe3c3d7
No related branches found
1 merge request!24Resolve "non negative omp"
...@@ -36,11 +36,11 @@ class NonNegativeOmpForestBinaryClassifier(OmpForestBinaryClassifier): ...@@ -36,11 +36,11 @@ class NonNegativeOmpForestBinaryClassifier(OmpForestBinaryClassifier):
:param X: some data to apply the forest to :param X: some data to apply the forest to
:return: a np.array of the predictions of the trees selected by OMP without applying the weight :return: a np.array of the predictions of the trees selected by OMP without applying the weight
""" """
forest_predictions = np.array([tree.predict(X) for tree in self._base_forest_estimator.estimators_]) forest_predictions = self._base_estimator_predictions(X)
if forest_size is not None: if forest_size is not None:
weights = self._omp.get_coef(forest_size) weights = self._omp.get_coef(forest_size)
select_trees = np.mean(forest_predictions[weights != 0], axis=0) select_trees = np.mean(forest_predictions[:, weights != 0], axis=1)
return select_trees return select_trees
else: else:
lst_predictions = [] lst_predictions = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment