Skip to content
Snippets Groups Projects
Commit 8b3a6c49 authored by Charly LAMOTHE's avatar Charly LAMOTHE
Browse files

Use OMP to compute the prediction instead of doing it manually (same results)

parent 0fce0319
No related branches found
No related tags found
1 merge request!3clean scripts
...@@ -38,7 +38,8 @@ class OmpForestRegressor(BaseEstimator): ...@@ -38,7 +38,8 @@ class OmpForestRegressor(BaseEstimator):
def fit(self, X_forest, y_forest, X_omp, y_omp): def fit(self, X_forest, y_forest, X_omp, y_omp):
self._forest = self._train_forest(X_forest, y_forest) self._forest = self._train_forest(X_forest, y_forest)
self._weights = self._extract_subforest(X_omp, y_omp) self._omp = self._extract_subforest(X_omp, y_omp)
self._weights = self._omp.coef_
return self return self
def score_regressor(self, X, y): def score_regressor(self, X, y):
...@@ -56,8 +57,8 @@ class OmpForestRegressor(BaseEstimator): ...@@ -56,8 +57,8 @@ class OmpForestRegressor(BaseEstimator):
if self._models_parameters.normalize_D: if self._models_parameters.normalize_D:
D /= self._forest_norms D /= self._forest_norms
# TODO: use self._models_parameters.normalize_weights here # TODO: use self._models_parameters.normalize_weights here?
predictions = D @ self._weights predictions = self._omp.predict(D)
return predictions return predictions
...@@ -109,10 +110,7 @@ class OmpForestRegressor(BaseEstimator): ...@@ -109,10 +110,7 @@ class OmpForestRegressor(BaseEstimator):
fit_intercept=False, normalize=False) fit_intercept=False, normalize=False)
self._logger.debug("Apply orthogonal maching pursuit on forest for {} extracted trees." self._logger.debug("Apply orthogonal maching pursuit on forest for {} extracted trees."
.format(self._models_parameters.extracted_forest_size)) .format(self._models_parameters.extracted_forest_size))
omp.fit(D, y) return omp.fit(D, y)
weights = omp.coef_
# question: why not to use directly the omp estimator instead of bypassing it using the coefs?
return weights
def _forest_prediction(self, X): def _forest_prediction(self, X):
return np.array([tree.predict(X) for tree in self._forest]).T return np.array([tree.predict(X) for tree in self._forest]).T
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment