From 7dd2aab5c9583fc129b5ae43ec76cfc9846cf209 Mon Sep 17 00:00:00 2001 From: Charly LAMOTHE <lamothe.c@intlocal.univ-amu.fr> Date: Fri, 8 Nov 2019 15:58:14 +0100 Subject: [PATCH] Fix normalize_weights impl --- code/bolsonaro/models/omp_forest_regressor.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/code/bolsonaro/models/omp_forest_regressor.py b/code/bolsonaro/models/omp_forest_regressor.py index b4896cd..b9abfa5 100644 --- a/code/bolsonaro/models/omp_forest_regressor.py +++ b/code/bolsonaro/models/omp_forest_regressor.py @@ -52,13 +52,14 @@ class OmpForestRegressor(BaseEstimator): :param X: :return: """ - D = self._forest_prediction(X) + forest_predictions = self._forest_prediction(X) if self._models_parameters.normalize_D: - D /= self._forest_norms + forest_predictions /= self._forest_norms - # TODO: use self._models_parameters.normalize_weights here? - predictions = self._omp.predict(D) + predictions = self._omp.predict(forest_predictions) * (1 / (np.sum(self._omp.coef_) / len(np.nonzero(self._omp.coef_)))) \ + if self._models_parameters.normalize_weights \ + else self._omp.predict(forest_predictions) return predictions -- GitLab