diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py index a1b5256a20a6c36f81152c8545b0b092b2c10f53..e1bc893dca0dae03c3b24a3265868547004b3a3e 100644 --- a/code/bolsonaro/trainer.py +++ b/code/bolsonaro/trainer.py @@ -51,7 +51,8 @@ class Trainer(object): def train(self, model): """ - :param model: Object with + :param model: An instance of either RandomForestRegressor, RandomForestClassifier, OmpForestRegressor, + OmpForestBinaryClassifier, OmpForestMulticlassClassifier. :return: """ @@ -72,32 +73,28 @@ class Trainer(object): self._end_time = time.time() def __score_func(self, model, X, y_true): - if type(model) == OmpForestRegressor: + if type(model) in [OmpForestRegressor, RandomForestRegressor]: y_pred = model.predict(X) result = mean_squared_error(y_true, y_pred) - - elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]: + elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, RandomForestClassifier]: y_pred = model.predict(X) result = accuracy_score(y_true, y_pred) - else: - y_pred = model.predict(X) - result = model.score(y_true, y_pred) - return result def __score_func_base(self, model, X, y_true): if type(model) == OmpForestRegressor: y_pred = model.predict_base_estimator(X) result = mean_squared_error(y_true, y_pred) - elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]: y_pred = model.predict_base_estimator(X) result = accuracy_score(y_true, y_pred) - - else: - y_pred = model.predict_base_estimator(X) - result = model.score(y_true, y_pred) + elif type(model) == RandomForestClassifier: + y_pred = model.predict(X) + result = accuracy_score(y_true, y_pred) + elif type(model) == RandomForestRegressor: + y_pred = model.predict(X) + result = mean_squared_error(y_true, y_pred) return result