From 125817c1415062007f115d079e2b220bb80ec89b Mon Sep 17 00:00:00 2001
From: Charly Lamothe <charly.lamothe@univ-amu.fr>
Date: Fri, 6 Mar 2020 06:27:00 +0100
Subject: [PATCH] Add back kmeans method in the trainer (disapeared during a
 merge)

---
 code/bolsonaro/trainer.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py
index 7070126..34b76c8 100644
--- a/code/bolsonaro/trainer.py
+++ b/code/bolsonaro/trainer.py
@@ -96,7 +96,7 @@ class Trainer(object):
         self._end_time = time.time()
 
     def __score_func(self, model, X, y_true, weights=True):
-        if type(model) in [OmpForestRegressor, RandomForestRegressor, SimilarityForestRegressor]:
+        if type(model) in [OmpForestRegressor, RandomForestRegressor]:
             if weights:
                 y_pred = model.predict(X)
             else:
@@ -109,12 +109,14 @@ class Trainer(object):
                 y_pred = model.predict_no_weights(X)
             if type(model) is OmpForestBinaryClassifier:
                 y_pred = np.sign(y_pred)
-                y_pred = np.where(y_pred==0, 1, y_pred)
+                y_pred = np.where(y_pred == 0, 1, y_pred)
             result = self._classification_score_metric(y_true, y_pred)
+        elif type(model) in [SimilarityForestRegressor, KMeansForestRegressor]:
+            result = model.score(X, y_true)
         return result
 
     def __score_func_base(self, model, X, y_true):
-        if type(model) == OmpForestRegressor:
+        if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor]:
             y_pred = model.predict_base_estimator(X)
             result = self._base_regression_score_metric(y_true, y_pred)
         elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
@@ -123,7 +125,7 @@ class Trainer(object):
         elif type(model) == RandomForestClassifier:
             y_pred = model.predict(X)
             result = self._base_classification_score_metric(y_true, y_pred)
-        elif type(model) in [RandomForestRegressor, SimilarityForestRegressor]:
+        elif type(model) is RandomForestRegressor:
             y_pred = model.predict(X)
             result = self._base_regression_score_metric(y_true, y_pred)
         return result
-- 
GitLab