From a98fd9321ba7bd10b227e1dcfca965ba1679cbff Mon Sep 17 00:00:00 2001
From: Charly Lamothe <charly.lamothe@univ-amu.fr>
Date: Wed, 25 Mar 2020 00:52:41 +0100
Subject: [PATCH] Fix missing EnsembleSelectionForestClassifier in sanity
 checks of trainer

---
 code/bolsonaro/trainer.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py
index df6f024..0986200 100644
--- a/code/bolsonaro/trainer.py
+++ b/code/bolsonaro/trainer.py
@@ -3,7 +3,7 @@ from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
 from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
 from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor, SimilarityForestClassifier
 from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor, KMeansForestClassifier
-from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor
+from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor, EnsembleSelectionForestClassifier
 from bolsonaro.error_handling.logger_factory import LoggerFactory
 from bolsonaro.data.task import Task
 from . import LOG_PATH
@@ -134,7 +134,8 @@ class Trainer(object):
                 y_pred = np.sign(y_pred)
                 y_pred = np.where(y_pred == 0, 1, y_pred)
             result = self._classification_score_metric(y_true, y_pred)
-        elif type(model) in [SimilarityForestRegressor, SimilarityForestClassifier, KMeansForestRegressor, EnsembleSelectionForestRegressor, KMeansForestClassifier]:
+        elif type(model) in [SimilarityForestRegressor, SimilarityForestClassifier, KMeansForestRegressor, EnsembleSelectionForestRegressor, KMeansForestClassifier,
+            EnsembleSelectionForestClassifier]:
             result = model.score(X, y_true)
         return result
 
@@ -142,7 +143,7 @@ class Trainer(object):
         if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor]:
             y_pred = model.predict_base_estimator(X)
             result = self._base_regression_score_metric(y_true, y_pred)
-        elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, KMeansForestClassifier, SimilarityForestClassifier]:
+        elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, KMeansForestClassifier, SimilarityForestClassifier, EnsembleSelectionForestClassifier]:
             y_pred = model.predict_base_estimator(X)
             result = self._base_classification_score_metric(y_true, y_pred)
         elif type(model) == RandomForestClassifier:
@@ -183,7 +184,8 @@ class Trainer(object):
         elif type(model) == OmpForestBinaryClassifier:
             model_weights = model._omp
 
-        if type(model) in [SimilarityForestRegressor, EnsembleSelectionForestRegressor, KMeansForestRegressor]:
+        if type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor, 
+            SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]:
             self._selected_trees = model.selected_trees
 
         if len(self._selected_trees) > 0:
-- 
GitLab