Skip to content
Snippets Groups Projects
Commit a98fd932 authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Fix missing EnsembleSelectionForestClassifier in sanity checks of trainer

parent c6160646
No related branches found
No related tags found
1 merge request!23Resolve "integration-sota"
......@@ -3,7 +3,7 @@ from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor, SimilarityForestClassifier
from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor, KMeansForestClassifier
from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor
from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor, EnsembleSelectionForestClassifier
from bolsonaro.error_handling.logger_factory import LoggerFactory
from bolsonaro.data.task import Task
from . import LOG_PATH
......@@ -134,7 +134,8 @@ class Trainer(object):
y_pred = np.sign(y_pred)
y_pred = np.where(y_pred == 0, 1, y_pred)
result = self._classification_score_metric(y_true, y_pred)
elif type(model) in [SimilarityForestRegressor, SimilarityForestClassifier, KMeansForestRegressor, EnsembleSelectionForestRegressor, KMeansForestClassifier]:
elif type(model) in [SimilarityForestRegressor, SimilarityForestClassifier, KMeansForestRegressor, EnsembleSelectionForestRegressor, KMeansForestClassifier,
EnsembleSelectionForestClassifier]:
result = model.score(X, y_true)
return result
......@@ -142,7 +143,7 @@ class Trainer(object):
if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor]:
y_pred = model.predict_base_estimator(X)
result = self._base_regression_score_metric(y_true, y_pred)
elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, KMeansForestClassifier, SimilarityForestClassifier]:
elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, KMeansForestClassifier, SimilarityForestClassifier, EnsembleSelectionForestClassifier]:
y_pred = model.predict_base_estimator(X)
result = self._base_classification_score_metric(y_true, y_pred)
elif type(model) == RandomForestClassifier:
......@@ -183,7 +184,8 @@ class Trainer(object):
elif type(model) == OmpForestBinaryClassifier:
model_weights = model._omp
if type(model) in [SimilarityForestRegressor, EnsembleSelectionForestRegressor, KMeansForestRegressor]:
if type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor,
SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]:
self._selected_trees = model.selected_trees
if len(self._selected_trees) > 0:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment