Skip to content
Snippets Groups Projects
Commit 8b3a6c49 authored by Charly LAMOTHE's avatar Charly LAMOTHE
Browse files

Use OMP to compute the prediction instead of doing it manually (same results)

parent 0fce0319
No related branches found
No related tags found
1 merge request!3clean scripts
......@@ -38,7 +38,8 @@ class OmpForestRegressor(BaseEstimator):
def fit(self, X_forest, y_forest, X_omp, y_omp):
self._forest = self._train_forest(X_forest, y_forest)
self._weights = self._extract_subforest(X_omp, y_omp)
self._omp = self._extract_subforest(X_omp, y_omp)
self._weights = self._omp.coef_
return self
def score_regressor(self, X, y):
......@@ -56,8 +57,8 @@ class OmpForestRegressor(BaseEstimator):
if self._models_parameters.normalize_D:
D /= self._forest_norms
# TODO: use self._models_parameters.normalize_weights here
predictions = D @ self._weights
# TODO: use self._models_parameters.normalize_weights here?
predictions = self._omp.predict(D)
return predictions
......@@ -109,10 +110,7 @@ class OmpForestRegressor(BaseEstimator):
fit_intercept=False, normalize=False)
self._logger.debug("Apply orthogonal maching pursuit on forest for {} extracted trees."
.format(self._models_parameters.extracted_forest_size))
omp.fit(D, y)
weights = omp.coef_
# question: why not to use directly the omp estimator instead of bypassing it using the coefs?
return weights
return omp.fit(D, y)
def _forest_prediction(self, X):
return np.array([tree.predict(X) for tree in self._forest]).T
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment