Commit af068a00 authored by Léo Bouscarrat's avatar Léo Bouscarrat
Browse files

Better to take non zero values of list as indicated in the numpy doc

parent 2d896dd1
......@@ -136,14 +136,11 @@ class SingleOmpForest(OmpForest):
:param X: a Forest
:return: a np.array of the predictions of the entire forest
"""
forest_predictions = self._base_estimator_predictions(X).T
forest_predictions = np.array([tree.predict(X) for tree in self._base_forest_estimator.estimators_])
if self._models_parameters.normalize_D:
forest_predictions /= self._forest_norms
weights = self._omp.coef_
omp_trees_indices = np.nonzero(weights)[0]
select_trees = np.mean(forest_predictions[omp_trees_indices], axis=0)
print(len(omp_trees_indices))
select_trees = np.mean(forest_predictions[weights != 0], axis=0)
return select_trees
......@@ -40,9 +40,7 @@ class OmpForestBinaryClassifier(SingleOmpForest):
forest_predictions /= self._forest_norms
weights = self._omp.coef_
omp_trees_indices = np.nonzero(weights)
omp_trees_predictions = forest_predictions[omp_trees_indices].T[1]
omp_trees_predictions = forest_predictions[weights != 0].T[1]
# Here forest_pred is the probability of being class 1.
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment