From 7dd2aab5c9583fc129b5ae43ec76cfc9846cf209 Mon Sep 17 00:00:00 2001
From: Charly LAMOTHE <lamothe.c@intlocal.univ-amu.fr>
Date: Fri, 8 Nov 2019 15:58:14 +0100
Subject: [PATCH] Fix normalize_weights impl

---
 code/bolsonaro/models/omp_forest_regressor.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/code/bolsonaro/models/omp_forest_regressor.py b/code/bolsonaro/models/omp_forest_regressor.py
index b4896cd..b9abfa5 100644
--- a/code/bolsonaro/models/omp_forest_regressor.py
+++ b/code/bolsonaro/models/omp_forest_regressor.py
@@ -52,13 +52,14 @@ class OmpForestRegressor(BaseEstimator):
         :param X:
         :return:
         """
-        D = self._forest_prediction(X)
+        forest_predictions = self._forest_prediction(X)
 
         if self._models_parameters.normalize_D:
-            D /= self._forest_norms
+            forest_predictions /= self._forest_norms
 
-        # TODO: use self._models_parameters.normalize_weights here?
-        predictions = self._omp.predict(D)
+        predictions = self._omp.predict(forest_predictions) * (1 / (np.sum(self._omp.coef_) / len(np.nonzero(self._omp.coef_)))) \
+            if self._models_parameters.normalize_weights \
+            else self._omp.predict(forest_predictions)
 
         return predictions
 
-- 
GitLab