diff --git a/code/bolsonaro/models/model_raw_results.py b/code/bolsonaro/models/model_raw_results.py
index df8b2ec0b10704a8a8c397b9012298e8b901e14b..8c5f9c7b4f49e157dc27e5403f84134bdb12f1f3 100644
--- a/code/bolsonaro/models/model_raw_results.py
+++ b/code/bolsonaro/models/model_raw_results.py
@@ -8,8 +8,8 @@ class ModelRawResults(object):
 
     def __init__(self, model_object, training_time,
         datetime, train_score, dev_score, test_score,
-        score_metric, train_score_regressor, dev_score_regressor,
-        test_score_regressor):
+        score_metric, train_score_base, dev_score_base,
+        test_score_base):
 
         self._model_object = model_object
         self._training_time = training_time
@@ -18,9 +18,9 @@ class ModelRawResults(object):
         self._dev_score = dev_score
         self._test_score = test_score
         self._score_metric = score_metric
-        self._train_score_regressor = train_score_regressor
-        self._dev_score_regressor = dev_score_regressor
-        self._test_score_regressor = test_score_regressor
+        self._train_score_base = train_score_base
+        self._dev_score_base = dev_score_base
+        self._test_score_base = test_score_base
 
     @property
     def model_object(self):
@@ -51,16 +51,16 @@ class ModelRawResults(object):
         return self._score_metric
 
     @property
-    def train_score_regressor(self):
-        return self._train_score_regressor
+    def train_score_base(self):
+        return self._train_score_base
 
     @property
-    def dev_score_regressor(self):
-        return self._dev_score_regressor
+    def dev_score_base(self):
+        return self._dev_score_base
 
     @property
-    def test_score_regressor(self):
-        return self._test_score_regressor
+    def test_score_base(self):
+        return self._test_score_base
 
     def save(self, models_dir):
         save_obj_to_pickle(models_dir + os.sep + 'model_raw_results.pickle',
diff --git a/code/bolsonaro/models/omp_forest.py b/code/bolsonaro/models/omp_forest.py
index ea6bf9108712043ff964ca39c1ec7728ddd26f20..16c3e1c9919a719ecedf4f2cd1d18ae4ee59fd13 100644
--- a/code/bolsonaro/models/omp_forest.py
+++ b/code/bolsonaro/models/omp_forest.py
@@ -17,6 +17,10 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):
     def models_parameters(self):
         return self._models_parameters
 
+    def predict_base_estimator(self, X):
+        """Return the predictions of the underlying base forest estimator on X."""
+        return self._base_forest_estimator.predict(X)
+
     def score_base_estimator(self, X, y):
         return self._base_forest_estimator.score(X, y)
 
diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py
index 549df6e7aceb93706b8d5902f994501aabdde015..a1b5256a20a6c36f81152c8545b0b092b2c10f53 100644
--- a/code/bolsonaro/trainer.py
+++ b/code/bolsonaro/trainer.py
@@ -1,8 +1,11 @@
 from bolsonaro.models.model_raw_results import ModelRawResults
+from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
+from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
 from bolsonaro.error_handling.logger_factory import LoggerFactory
 from . import LOG_PATH
 
 from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
+from sklearn.metrics import mean_squared_error, accuracy_score
 import time
 import datetime
 import numpy as np
@@ -68,32 +71,62 @@ class Trainer(object):
         )
         self._end_time = time.time()
 
+    def __score_func(self, model, X, y_true):
+        """Score model on (X, y_true): MSE for OmpForestRegressor, accuracy for OMP classifiers, model.score otherwise."""
+        if type(model) == OmpForestRegressor:
+            y_pred = model.predict(X)
+            result = mean_squared_error(y_true, y_pred)
+
+        elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
+            y_pred = model.predict(X)
+            result = accuracy_score(y_true, y_pred)
+
+        else:
+            # Estimator-style score() takes (X, y) and predicts internally; it must not be given (y_true, y_pred).
+            result = model.score(X, y_true)
+
+        return result
+
+    def __score_func_base(self, model, X, y_true):
+        """Same as __score_func, but OMP models are scored on their base forest's predictions."""
+        if type(model) == OmpForestRegressor:
+            y_pred = model.predict_base_estimator(X)
+            result = mean_squared_error(y_true, y_pred)
+
+        elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
+            y_pred = model.predict_base_estimator(X)
+            result = accuracy_score(y_true, y_pred)
+
+        else:
+            # Plain sklearn forests have no predict_base_estimator(); the model is its own base, so score it directly.
+            result = model.score(X, y_true)
+
+        return result
+
     def compute_results(self, model, models_dir):
         """
         :param model: Object with
         :param models_dir: Where the results will be saved
         """
-        score_func = model.score if type(model) in [RandomForestRegressor, RandomForestClassifier] \
-            else model.score_base_estimator
         results = ModelRawResults(
             model_object=model,
             training_time=self._end_time - self._begin_time,
             datetime=datetime.datetime.now(),
-            train_score=model.score(self._dataset.X_train, self._dataset.y_train),
-            dev_score=model.score(self._dataset.X_dev, self._dataset.y_dev),
-            test_score=model.score(self._dataset.X_test, self._dataset.y_test),
+            train_score=self.__score_func(model, self._dataset.X_train, self._dataset.y_train),
+            dev_score=self.__score_func(model, self._dataset.X_dev, self._dataset.y_dev),
+            test_score=self.__score_func(model, self._dataset.X_test, self._dataset.y_test),
+            train_score_base=self.__score_func_base(model, self._dataset.X_train, self._dataset.y_train),
+            dev_score_base=self.__score_func_base(model, self._dataset.X_dev, self._dataset.y_dev),
+            test_score_base=self.__score_func_base(model, self._dataset.X_test, self._dataset.y_test),
             score_metric='mse' if type(model) in [RandomForestRegressor, RandomForestClassifier] \
-                else model.DEFAULT_SCORE_METRIC, # TODO: resolve the used metric in a proper way
-            train_score_regressor=score_func(self._dataset.X_train, self._dataset.y_train),
-            dev_score_regressor=score_func(self._dataset.X_dev, self._dataset.y_dev),
-            test_score_regressor=score_func(self._dataset.X_test, self._dataset.y_test)
+                else model.DEFAULT_SCORE_METRIC,  # TODO: resolve the used metric in a proper way
         )
         results.save(models_dir)
-        self._logger.info("Base performance on test: {}".format(results.test_score_regressor))
+        self._logger.info("Base performance on test: {}".format(results.test_score_base))
         self._logger.info("Performance on test: {}".format(results.test_score))
 
-        self._logger.info("Base performance on train: {}".format(results.train_score_regressor))
+        self._logger.info("Base performance on train: {}".format(results.train_score_base))
         self._logger.info("Performance on train: {}".format(results.train_score))
 
-        self._logger.info("Base performance on dev: {}".format(results.dev_score_regressor))
+        self._logger.info("Base performance on dev: {}".format(results.dev_score_base))
         self._logger.info("Performance on dev: {}".format(results.dev_score))