Skip to content
Snippets Groups Projects
Commit 17d3addc authored by Charly Lamothe's avatar Charly Lamothe
Browse files

- The trainer can now specify which metrics to use;

- Remove useless getter in Dataset class.
parent 628fba24
Branches
No related tags found
1 merge request!9Resolve "Experiment pipeline"
......@@ -14,10 +14,6 @@ class Dataset(object):
def task(self):
return self._task
@property
def dataset_parameters(self):
return self._dataset_parameters
@property
def X_train(self):
return self._X_train
......
......@@ -2,6 +2,7 @@ from bolsonaro.models.model_raw_results import ModelRawResults
from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
from bolsonaro.error_handling.logger_factory import LoggerFactory
from bolsonaro.data.task import Task
from . import LOG_PATH
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
......@@ -16,13 +17,30 @@ class Trainer(object):
Class capable of fitting any model object to some prepared data then evaluate and save results through the `train` method.
"""
def __init__(self, dataset):
def __init__(self, dataset, regression_score_metric=mean_squared_error, classification_score_metric=accuracy_score,
base_regression_score_metric=mean_squared_error, base_classification_score_metric=accuracy_score):
"""
:param dataset: Object with X_train, y_train, X_dev, y_dev, X_test and Y_test attributes
"""
self._dataset = dataset
self._logger = LoggerFactory.create(LOG_PATH, __name__)
self._regression_score_metric = regression_score_metric
self._classification_score_metric = classification_score_metric
self._base_regression_score_metric = base_regression_score_metric
self._base_classification_score_metric = base_classification_score_metric
self._score_metric_name = regression_score_metric.__name__ if dataset.task == Task.REGRESSION \
else classification_score_metric.__name__
self._base_score_metric_name = base_regression_score_metric.__name__ if dataset.task == Task.REGRESSION \
else base_classification_score_metric.__name__
@property
def score_metric_name(self):
return self._score_metric_name
@property
def base_score_metric_name(self):
return self._base_score_metric_name
def init(self, model):
if type(model) in [RandomForestRegressor, RandomForestClassifier]:
......@@ -75,27 +93,25 @@ class Trainer(object):
def __score_func(self, model, X, y_true):
if type(model) in [OmpForestRegressor, RandomForestRegressor]:
y_pred = model.predict(X)
result = mean_squared_error(y_true, y_pred)
result = self._regression_score_metric(y_true, y_pred)
elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, RandomForestClassifier]:
y_pred = model.predict(X)
result = accuracy_score(y_true, y_pred)
result = self._classification_score_metric(y_true, y_pred)
return result
def __score_func_base(self, model, X, y_true):
if type(model) == OmpForestRegressor:
y_pred = model.predict_base_estimator(X)
result = mean_squared_error(y_true, y_pred)
result = self._base_regression_score_metric(y_true, y_pred)
elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
y_pred = model.predict_base_estimator(X)
result = accuracy_score(y_true, y_pred)
result = self._base_classification_score_metric(y_true, y_pred)
elif type(model) == RandomForestClassifier:
y_pred = model.predict(X)
result = accuracy_score(y_true, y_pred)
result = self._base_classification_score_metric(y_true, y_pred)
elif type(model) == RandomForestRegressor:
y_pred = model.predict(X)
result = mean_squared_error(y_true, y_pred)
result = self._base_regression_score_metric(y_true, y_pred)
return result
def compute_results(self, model, models_dir):
......@@ -113,8 +129,8 @@ class Trainer(object):
train_score_base=self.__score_func_base(model, self._dataset.X_train, self._dataset.y_train),
dev_score_base=self.__score_func_base(model, self._dataset.X_dev, self._dataset.y_dev),
test_score_base=self.__score_func_base(model, self._dataset.X_test, self._dataset.y_test),
score_metric='mse' if type(model) in [RandomForestRegressor, RandomForestClassifier] \
else model.DEFAULT_SCORE_METRIC, # TODO: resolve the used metric in a proper way
score_metric=self._score_metric_name,
base_score_metric=self._base_score_metric_name
)
results.save(models_dir)
self._logger.info("Base performance on test: {}".format(results.test_score_base))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment