diff --git a/code/bolsonaro/models/omp_forest.py b/code/bolsonaro/models/omp_forest.py index 5b211b0d2900a42ac5fedf95d1c98bee4f1fa5e9..5b947d327693020b51c7da778d4855274454de93 100644 --- a/code/bolsonaro/models/omp_forest.py +++ b/code/bolsonaro/models/omp_forest.py @@ -142,8 +142,8 @@ class SingleOmpForest(OmpForest): forest_predictions /= self._forest_norms weights = self._omp.coef_ - omp_trees_indices = np.nonzero(weights) + omp_trees_indices = np.nonzero(weights)[0] select_trees = np.mean(forest_predictions[omp_trees_indices], axis=0) - + print(len(omp_trees_indices)) return select_trees