diff --git a/code/bolsonaro/models/omp_forest_regressor.py b/code/bolsonaro/models/omp_forest_regressor.py index b4896cda87396da32c37eccdbe7cf8c4f743976e..b9abfa5110f68907c7615c41d6f03a78e24daf25 100644 --- a/code/bolsonaro/models/omp_forest_regressor.py +++ b/code/bolsonaro/models/omp_forest_regressor.py @@ -52,13 +52,14 @@ class OmpForestRegressor(BaseEstimator): :param X: :return: """ - D = self._forest_prediction(X) + forest_predictions = self._forest_prediction(X) if self._models_parameters.normalize_D: - D /= self._forest_norms + forest_predictions /= self._forest_norms - # TODO: use self._models_parameters.normalize_weights here? - predictions = self._omp.predict(D) + predictions = self._omp.predict(forest_predictions) * (1 / (np.sum(self._omp.coef_) / len(np.nonzero(self._omp.coef_)))) \ + if self._models_parameters.normalize_weights \ + else self._omp.predict(forest_predictions) return predictions