From 3a2ec5cbf3a657bb509a2b75bb44f93d52b8b1df Mon Sep 17 00:00:00 2001 From: Charly Lamothe <charly.lamothe@univ-amu.fr> Date: Fri, 20 Dec 2019 09:38:34 +0100 Subject: [PATCH] Fix score func in trainer --- code/bolsonaro/trainer.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py index a1b5256..e1bc893 100644 --- a/code/bolsonaro/trainer.py +++ b/code/bolsonaro/trainer.py @@ -51,7 +51,8 @@ class Trainer(object): def train(self, model): """ - :param model: Object with + :param model: An instance of either RandomForestRegressor, RandomForestClassifier, OmpForestRegressor, + OmpForestBinaryClassifier, OmpForestMulticlassClassifier. :return: """ @@ -72,32 +73,28 @@ class Trainer(object): self._end_time = time.time() def __score_func(self, model, X, y_true): - if type(model) == OmpForestRegressor: + if type(model) in [OmpForestRegressor, RandomForestRegressor]: y_pred = model.predict(X) result = mean_squared_error(y_true, y_pred) - - elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]: + elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, RandomForestClassifier]: y_pred = model.predict(X) result = accuracy_score(y_true, y_pred) - else: - y_pred = model.predict(X) - result = model.score(y_true, y_pred) - return result def __score_func_base(self, model, X, y_true): if type(model) == OmpForestRegressor: y_pred = model.predict_base_estimator(X) result = mean_squared_error(y_true, y_pred) - elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]: y_pred = model.predict_base_estimator(X) result = accuracy_score(y_true, y_pred) - - else: - y_pred = model.predict_base_estimator(X) - result = model.score(y_true, y_pred) + elif type(model) == RandomForestClassifier: + y_pred = model.predict(X) + result = accuracy_score(y_true, y_pred) + elif type(model) == RandomForestRegressor: + y_pred = model.predict(X) + result = mean_squared_error(y_true, y_pred) return result -- GitLab