Skip to content
Snippets Groups Projects
Commit 7dd2aab5 authored by Charly LAMOTHE's avatar Charly LAMOTHE
Browse files

Fix normalize_weights impl

parent 708b1731
No related branches found
No related tags found
1 merge request!3clean scripts
......@@ -52,13 +52,14 @@ class OmpForestRegressor(BaseEstimator):
:param X:
:return:
"""
D = self._forest_prediction(X)
forest_predictions = self._forest_prediction(X)
if self._models_parameters.normalize_D:
D /= self._forest_norms
forest_predictions /= self._forest_norms
# TODO: use self._models_parameters.normalize_weights here?
predictions = self._omp.predict(D)
predictions = self._omp.predict(forest_predictions) * (1 / (np.sum(self._omp.coef_) / len(np.nonzero(self._omp.coef_)))) \
if self._models_parameters.normalize_weights \
else self._omp.predict(forest_predictions)
return predictions
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment