Commit c43ae6c4 authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Add few stage5 results, code fixes and begins to caught warnings

parent 5d348f2e
...@@ -52,7 +52,7 @@ class EnsembleSelectionForestRegressor(BaseEstimator, metaclass=ABCMeta): ...@@ -52,7 +52,7 @@ class EnsembleSelectionForestRegressor(BaseEstimator, metaclass=ABCMeta):
del class_list[candidate_index] del class_list[candidate_index]
def score(self, X, y): def score(self, X, y):
predictions = self._predict_base_estimator(X) predictions = self.predict_base_estimator(X)
return self._score_metric(predictions, y) return self._score_metric(predictions, y)
def predict_base_estimator(self, X): def predict_base_estimator(self, X):
......
...@@ -5,6 +5,7 @@ from abc import abstractmethod, ABCMeta ...@@ -5,6 +5,7 @@ from abc import abstractmethod, ABCMeta
import numpy as np import numpy as np
from sklearn.linear_model import OrthogonalMatchingPursuit from sklearn.linear_model import OrthogonalMatchingPursuit
from sklearn.base import BaseEstimator from sklearn.base import BaseEstimator
import warnings
class OmpForest(BaseEstimator, metaclass=ABCMeta): class OmpForest(BaseEstimator, metaclass=ABCMeta):
...@@ -109,7 +110,17 @@ class SingleOmpForest(OmpForest): ...@@ -109,7 +110,17 @@ class SingleOmpForest(OmpForest):
super().__init__(models_parameters, base_forest_estimator) super().__init__(models_parameters, base_forest_estimator)
def fit_omp(self, atoms, objective): def fit_omp(self, atoms, objective):
self._omp.fit(atoms, objective) with warnings.catch_warnings(record=True) as caught_warnings:
# Cause all warnings to always be triggered.
warnings.simplefilter("always")
self._omp.fit(atoms, objective)
# ignore any non-custom warnings that may be in the list
caught_warnings = list(filter(lambda i: i.message != RuntimeWarning(omp_premature_warning), caught_warnings))
if len(caught_warnings) > 0:
logger.error(f'number of linear dependences in the dictionary: {len(caught_warnings)}. model parameters: {str(self._models_parameters.__dict__)}')
def predict(self, X): def predict(self, X):
""" """
......
from bolsonaro.models.omp_forest import OmpForest, SingleOmpForest from bolsonaro.models.omp_forest import OmpForest, SingleOmpForest
from bolsonaro.utils import binarize_class_data from bolsonaro.utils import binarize_class_data, omp_premature_warning
import numpy as np import numpy as np
from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import OrthogonalMatchingPursuit from sklearn.linear_model import OrthogonalMatchingPursuit
import warnings
class OmpForestBinaryClassifier(SingleOmpForest): class OmpForestBinaryClassifier(SingleOmpForest):
...@@ -92,7 +93,19 @@ class OmpForestMulticlassClassifier(OmpForest): ...@@ -92,7 +93,19 @@ class OmpForestMulticlassClassifier(OmpForest):
omp_class = OrthogonalMatchingPursuit( omp_class = OrthogonalMatchingPursuit(
n_nonzero_coefs=self.models_parameters.extracted_forest_size, n_nonzero_coefs=self.models_parameters.extracted_forest_size,
fit_intercept=True, normalize=False) fit_intercept=True, normalize=False)
omp_class.fit(atoms_binary, objective_binary)
with warnings.catch_warnings(record=True) as caught_warnings:
# Cause all warnings to always be triggered.
warnings.simplefilter("always")
omp_class.fit(atoms_binary, objective_binary)
# ignore any non-custom warnings that may be in the list
caught_warnings = list(filter(lambda i: i.message != RuntimeWarning(omp_premature_warning), caught_warnings))
if len(caught_warnings) > 0:
logger.error(f'number of linear dependences in the dictionary: {len(caught_warnings)}. model parameters: {str(self._models_parameters.__dict__)}')
self._dct_class_omp[class_label] = omp_class self._dct_class_omp[class_label] = omp_class
return self._dct_class_omp return self._dct_class_omp
......
...@@ -124,3 +124,7 @@ def is_float(value): ...@@ -124,3 +124,7 @@ def is_float(value):
return True return True
except ValueError: except ValueError:
return False return False
omp_premature_warning = """ Orthogonal matching pursuit ended prematurely due to linear
dependence in the dictionary. The requested precision might not have been met.
"""
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment