Skip to content
Snippets Groups Projects
Commit 7dd2aab5 authored by Charly LAMOTHE's avatar Charly LAMOTHE
Browse files

Fix normalize_weights impl

parent 708b1731
No related branches found
No related tags found
1 merge request!3clean scripts
...@@ -52,13 +52,14 @@ class OmpForestRegressor(BaseEstimator): ...@@ -52,13 +52,14 @@ class OmpForestRegressor(BaseEstimator):
:param X: :param X:
:return: :return:
""" """
D = self._forest_prediction(X) forest_predictions = self._forest_prediction(X)
if self._models_parameters.normalize_D: if self._models_parameters.normalize_D:
D /= self._forest_norms forest_predictions /= self._forest_norms
# TODO: use self._models_parameters.normalize_weights here? predictions = self._omp.predict(forest_predictions) * (1 / (np.sum(self._omp.coef_) / len(np.nonzero(self._omp.coef_)))) \
predictions = self._omp.predict(D) if self._models_parameters.normalize_weights \
else self._omp.predict(forest_predictions)
return predictions return predictions
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment