From 125817c1415062007f115d079e2b220bb80ec89b Mon Sep 17 00:00:00 2001 From: Charly Lamothe <charly.lamothe@univ-amu.fr> Date: Fri, 6 Mar 2020 06:27:00 +0100 Subject: [PATCH] Add back kmeans method in the trainer (disapeared during a merge) --- code/bolsonaro/trainer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py index 7070126..34b76c8 100644 --- a/code/bolsonaro/trainer.py +++ b/code/bolsonaro/trainer.py @@ -96,7 +96,7 @@ class Trainer(object): self._end_time = time.time() def __score_func(self, model, X, y_true, weights=True): - if type(model) in [OmpForestRegressor, RandomForestRegressor, SimilarityForestRegressor]: + if type(model) in [OmpForestRegressor, RandomForestRegressor]: if weights: y_pred = model.predict(X) else: @@ -109,12 +109,14 @@ class Trainer(object): y_pred = model.predict_no_weights(X) if type(model) is OmpForestBinaryClassifier: y_pred = np.sign(y_pred) - y_pred = np.where(y_pred==0, 1, y_pred) + y_pred = np.where(y_pred == 0, 1, y_pred) result = self._classification_score_metric(y_true, y_pred) + elif type(model) in [SimilarityForestRegressor, KMeansForestRegressor]: + result = model.score(X, y_true) return result def __score_func_base(self, model, X, y_true): - if type(model) == OmpForestRegressor: + if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor]: y_pred = model.predict_base_estimator(X) result = self._base_regression_score_metric(y_true, y_pred) elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]: @@ -123,7 +125,7 @@ class Trainer(object): elif type(model) == RandomForestClassifier: y_pred = model.predict(X) result = self._base_classification_score_metric(y_true, y_pred) - elif type(model) in [RandomForestRegressor, SimilarityForestRegressor]: + elif type(model) is RandomForestRegressor: y_pred = model.predict(X) result = self._base_regression_score_metric(y_true, y_pred) return result -- GitLab