Skip to content
Snippets Groups Projects
Commit a98fd932 authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Fix missing EnsembleSelectionForestClassifier in sanity checks of trainer

parent c6160646
No related branches found
1 merge request!23Resolve "integration-sota"
...@@ -3,7 +3,7 @@ from bolsonaro.models.omp_forest_regressor import OmpForestRegressor ...@@ -3,7 +3,7 @@ from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor, SimilarityForestClassifier from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor, SimilarityForestClassifier
from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor, KMeansForestClassifier from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor, KMeansForestClassifier
from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor, EnsembleSelectionForestClassifier
from bolsonaro.error_handling.logger_factory import LoggerFactory from bolsonaro.error_handling.logger_factory import LoggerFactory
from bolsonaro.data.task import Task from bolsonaro.data.task import Task
from . import LOG_PATH from . import LOG_PATH
...@@ -134,7 +134,8 @@ class Trainer(object): ...@@ -134,7 +134,8 @@ class Trainer(object):
y_pred = np.sign(y_pred) y_pred = np.sign(y_pred)
y_pred = np.where(y_pred == 0, 1, y_pred) y_pred = np.where(y_pred == 0, 1, y_pred)
result = self._classification_score_metric(y_true, y_pred) result = self._classification_score_metric(y_true, y_pred)
elif type(model) in [SimilarityForestRegressor, SimilarityForestClassifier, KMeansForestRegressor, EnsembleSelectionForestRegressor, KMeansForestClassifier]: elif type(model) in [SimilarityForestRegressor, SimilarityForestClassifier, KMeansForestRegressor, EnsembleSelectionForestRegressor, KMeansForestClassifier,
EnsembleSelectionForestClassifier]:
result = model.score(X, y_true) result = model.score(X, y_true)
return result return result
...@@ -142,7 +143,7 @@ class Trainer(object): ...@@ -142,7 +143,7 @@ class Trainer(object):
if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor]: if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor]:
y_pred = model.predict_base_estimator(X) y_pred = model.predict_base_estimator(X)
result = self._base_regression_score_metric(y_true, y_pred) result = self._base_regression_score_metric(y_true, y_pred)
elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, KMeansForestClassifier, SimilarityForestClassifier]: elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, KMeansForestClassifier, SimilarityForestClassifier, EnsembleSelectionForestClassifier]:
y_pred = model.predict_base_estimator(X) y_pred = model.predict_base_estimator(X)
result = self._base_classification_score_metric(y_true, y_pred) result = self._base_classification_score_metric(y_true, y_pred)
elif type(model) == RandomForestClassifier: elif type(model) == RandomForestClassifier:
...@@ -183,7 +184,8 @@ class Trainer(object): ...@@ -183,7 +184,8 @@ class Trainer(object):
elif type(model) == OmpForestBinaryClassifier: elif type(model) == OmpForestBinaryClassifier:
model_weights = model._omp model_weights = model._omp
if type(model) in [SimilarityForestRegressor, EnsembleSelectionForestRegressor, KMeansForestRegressor]: if type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor,
SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]:
self._selected_trees = model.selected_trees self._selected_trees = model.selected_trees
if len(self._selected_trees) > 0: if len(self._selected_trees) > 0:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment