diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py index df6f024b0dd172ae62acdd866ce6248fb2117756..09862003daac7d12c984a8f4f3aca9c9f187052e 100644 --- a/code/bolsonaro/trainer.py +++ b/code/bolsonaro/trainer.py @@ -3,7 +3,7 @@ from bolsonaro.models.omp_forest_regressor import OmpForestRegressor from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor, SimilarityForestClassifier from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor, KMeansForestClassifier -from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor +from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor, EnsembleSelectionForestClassifier from bolsonaro.error_handling.logger_factory import LoggerFactory from bolsonaro.data.task import Task from . import LOG_PATH @@ -134,7 +134,8 @@ class Trainer(object): y_pred = np.sign(y_pred) y_pred = np.where(y_pred == 0, 1, y_pred) result = self._classification_score_metric(y_true, y_pred) - elif type(model) in [SimilarityForestRegressor, SimilarityForestClassifier, KMeansForestRegressor, EnsembleSelectionForestRegressor, KMeansForestClassifier]: + elif type(model) in [SimilarityForestRegressor, SimilarityForestClassifier, KMeansForestRegressor, EnsembleSelectionForestRegressor, KMeansForestClassifier, + EnsembleSelectionForestClassifier]: result = model.score(X, y_true) return result @@ -142,7 +143,7 @@ class Trainer(object): if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor]: y_pred = model.predict_base_estimator(X) result = self._base_regression_score_metric(y_true, y_pred) - elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, KMeansForestClassifier, SimilarityForestClassifier]: + elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, KMeansForestClassifier, SimilarityForestClassifier, EnsembleSelectionForestClassifier]: y_pred = model.predict_base_estimator(X) result = self._base_classification_score_metric(y_true, y_pred) elif type(model) == RandomForestClassifier: @@ -183,7 +184,8 @@ class Trainer(object): elif type(model) == OmpForestBinaryClassifier: model_weights = model._omp - if type(model) in [SimilarityForestRegressor, EnsembleSelectionForestRegressor, KMeansForestRegressor]: + if type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor, + SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]: self._selected_trees = model.selected_trees if len(self._selected_trees) > 0: