From 3a2ec5cbf3a657bb509a2b75bb44f93d52b8b1df Mon Sep 17 00:00:00 2001
From: Charly Lamothe <charly.lamothe@univ-amu.fr>
Date: Fri, 20 Dec 2019 09:38:34 +0100
Subject: [PATCH] Fix score func in trainer

---
 code/bolsonaro/trainer.py | 23 ++++++++++-------------
 1 file changed, 10 insertions(+), 13 deletions(-)

diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py
index a1b5256..e1bc893 100644
--- a/code/bolsonaro/trainer.py
+++ b/code/bolsonaro/trainer.py
@@ -51,7 +51,8 @@ class Trainer(object):
 
     def train(self, model):
         """
-        :param model: Object with
+        :param model: An instance of either RandomForestRegressor, RandomForestClassifier, OmpForestRegressor,
+            OmpForestBinaryClassifier, OmpForestMulticlassClassifier.
         :return:
         """
 
@@ -72,32 +73,28 @@ class Trainer(object):
         self._end_time = time.time()
 
     def __score_func(self, model, X, y_true):
-        if type(model) == OmpForestRegressor:
+        if type(model) in [OmpForestRegressor, RandomForestRegressor]:
             y_pred = model.predict(X)
             result = mean_squared_error(y_true, y_pred)
-
-        elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
+        elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, RandomForestClassifier]:
             y_pred = model.predict(X)
             result = accuracy_score(y_true, y_pred)
 
-        else:
-            y_pred = model.predict(X)
-            result = model.score(y_true, y_pred)
-
         return result
 
     def __score_func_base(self, model, X, y_true):
         if type(model) == OmpForestRegressor:
             y_pred = model.predict_base_estimator(X)
             result = mean_squared_error(y_true, y_pred)
-
         elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
             y_pred = model.predict_base_estimator(X)
             result = accuracy_score(y_true, y_pred)
-
-        else:
-            y_pred = model.predict_base_estimator(X)
-            result = model.score(y_true, y_pred)
+        elif type(model) == RandomForestClassifier:
+            y_pred = model.predict(X)
+            result = accuracy_score(y_true, y_pred)
+        elif type(model) == RandomForestRegressor:
+            y_pred = model.predict(X)
+            result = mean_squared_error(y_true, y_pred)
 
         return result
 
-- 
GitLab