Skip to content
Snippets Groups Projects
Commit c43ae6c4 authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Add few stage5 results, code fixes and begins to caught warnings

parent 5d348f2e
No related branches found
No related tags found
No related merge requests found
......@@ -52,7 +52,7 @@ class EnsembleSelectionForestRegressor(BaseEstimator, metaclass=ABCMeta):
del class_list[candidate_index]
def score(self, X, y):
predictions = self._predict_base_estimator(X)
predictions = self.predict_base_estimator(X)
return self._score_metric(predictions, y)
def predict_base_estimator(self, X):
......
......@@ -5,6 +5,7 @@ from abc import abstractmethod, ABCMeta
import numpy as np
from sklearn.linear_model import OrthogonalMatchingPursuit
from sklearn.base import BaseEstimator
import warnings
class OmpForest(BaseEstimator, metaclass=ABCMeta):
......@@ -109,7 +110,17 @@ class SingleOmpForest(OmpForest):
super().__init__(models_parameters, base_forest_estimator)
def fit_omp(self, atoms, objective):
self._omp.fit(atoms, objective)
with warnings.catch_warnings(record=True) as caught_warnings:
# Cause all warnings to always be triggered.
warnings.simplefilter("always")
self._omp.fit(atoms, objective)
# ignore any non-custom warnings that may be in the list
caught_warnings = list(filter(lambda i: i.message != RuntimeWarning(omp_premature_warning), caught_warnings))
if len(caught_warnings) > 0:
logger.error(f'number of linear dependences in the dictionary: {len(caught_warnings)}. model parameters: {str(self._models_parameters.__dict__)}')
def predict(self, X):
"""
......
from bolsonaro.models.omp_forest import OmpForest, SingleOmpForest
from bolsonaro.utils import binarize_class_data
from bolsonaro.utils import binarize_class_data, omp_premature_warning
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import OrthogonalMatchingPursuit
import warnings
class OmpForestBinaryClassifier(SingleOmpForest):
......@@ -92,7 +93,19 @@ class OmpForestMulticlassClassifier(OmpForest):
omp_class = OrthogonalMatchingPursuit(
n_nonzero_coefs=self.models_parameters.extracted_forest_size,
fit_intercept=True, normalize=False)
omp_class.fit(atoms_binary, objective_binary)
with warnings.catch_warnings(record=True) as caught_warnings:
# Cause all warnings to always be triggered.
warnings.simplefilter("always")
omp_class.fit(atoms_binary, objective_binary)
# ignore any non-custom warnings that may be in the list
caught_warnings = list(filter(lambda i: i.message != RuntimeWarning(omp_premature_warning), caught_warnings))
if len(caught_warnings) > 0:
logger.error(f'number of linear dependences in the dictionary: {len(caught_warnings)}. model parameters: {str(self._models_parameters.__dict__)}')
self._dct_class_omp[class_label] = omp_class
return self._dct_class_omp
......
......@@ -124,3 +124,7 @@ def is_float(value):
return True
except ValueError:
return False
omp_premature_warning = """ Orthogonal matching pursuit ended prematurely due to linear
dependence in the dictionary. The requested precision might not have been met.
"""
results/boston/stage5/losses_base-random-omp-omp_without_weights-kmeans.png

64 KiB

results/diabetes/stage5/losses_base-random-omp-omp_without_weights-kmeans.png

63.2 KiB

results/diabetes/stage5/losses_base-random-omp-omp_without_weights-similarity.png

57.9 KiB

results/diamonds/stage5/losses_base-random-omp-omp_without_weights-kmeans.png

65.9 KiB

0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment