Commit b56e4254 authored by Léo Bouscarrat

Hardcoded metric functions for comparison
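Replace the generic `model.score` calls in `Trainer.compute_results` with hardcoded metric functions (mean squared error for `OmpForestRegressor`, accuracy for `OmpForestBinaryClassifier` and `OmpForestMulticlassClassifier`), add a `predict_base_estimator` method to `OmpForest`, and rename the `*_score_regressor` result fields to `*_score_base` so they also apply to classifiers.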

parent 7330792e
1 merge request: !9 Resolve "Experiment pipeline"
@@ -8,8 +8,8 @@ class ModelRawResults(object):

     def __init__(self, model_object, training_time,
         datetime, train_score, dev_score, test_score,
-        score_metric, train_score_regressor, dev_score_regressor,
-        test_score_regressor):
+        score_metric, train_score_base, dev_score_base,
+        test_score_base):
         self._model_object = model_object
         self._training_time = training_time

@@ -18,9 +18,9 @@ class ModelRawResults(object):

         self._dev_score = dev_score
         self._test_score = test_score
         self._score_metric = score_metric
-        self._train_score_regressor = train_score_regressor
-        self._dev_score_regressor = dev_score_regressor
-        self._test_score_regressor = test_score_regressor
+        self._train_score_base = train_score_base
+        self._dev_score_base = dev_score_base
+        self._test_score_base = test_score_base

     @property
     def model_object(self):

@@ -51,16 +51,16 @@ class ModelRawResults(object):

         return self._score_metric

     @property
-    def train_score_regressor(self):
-        return self._train_score_regressor
+    def train_score_base(self):
+        return self._train_score_base

     @property
-    def dev_score_regressor(self):
-        return self._dev_score_regressor
+    def dev_score_base(self):
+        return self._dev_score_base

     @property
-    def test_score_regressor(self):
-        return self._test_score_regressor
+    def test_score_base(self):
+        return self._test_score_base

     def save(self, models_dir):
         save_obj_to_pickle(models_dir + os.sep + 'model_raw_results.pickle',
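The `save_obj_to_pickle` helper is defined elsewhere in the repository and is not part of this diff; a minimal sketch of what such a helper typically looks like (the signature is taken from the call above, the body is an assumption):

```python
import pickle

def save_obj_to_pickle(filename, obj):
    # Hypothetical sketch: serialize the results object to disk so later
    # analysis steps can reload it; the real helper in the bolsonaro
    # utilities may differ.
    with open(filename, 'wb') as output_file:
        pickle.dump(obj, output_file)
```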
@@ -17,6 +17,9 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):

     def models_parameters(self):
         return self._models_parameters

+    def predict_base_estimator(self, X):
+        return self._base_forest_estimator.predict(X)
+
     def score_base_estimator(self, X, y):
         return self._base_forest_estimator.score(X, y)

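`score_base_estimator` delegates to scikit-learn's `score`, which returns R² for regressors and accuracy for classifiers. The new `predict_base_estimator` instead exposes raw predictions of the unpruned base forest, so the trainer below can apply its own hardcoded metric (MSE). A short illustration of the difference, assuming a plain RandomForestRegressor as a stand-in for the base forest:

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error

# Stand-in for the base forest wrapped by OmpForest.
X, y = make_regression(n_samples=200, n_features=10, random_state=0)
forest = RandomForestRegressor(n_estimators=10, random_state=0).fit(X, y)

print(forest.score(X, y))                        # R^2: what score_base_estimator returns
print(mean_squared_error(y, forest.predict(X)))  # MSE: what __score_func_base computes
```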
 from bolsonaro.models.model_raw_results import ModelRawResults
+from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
+from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
 from bolsonaro.error_handling.logger_factory import LoggerFactory
 from . import LOG_PATH
 from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
+from sklearn.metrics import mean_squared_error, accuracy_score
 import time
 import datetime
 import numpy as np

@@ -68,32 +71,60 @@ class Trainer(object):

         )
         self._end_time = time.time()

+    def __score_func(self, model, X, y_true):
+        if type(model) == OmpForestRegressor:
+            y_pred = model.predict(X)
+            result = mean_squared_error(y_true, y_pred)
+        elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
+            y_pred = model.predict(X)
+            result = accuracy_score(y_true, y_pred)
+        else:
+            y_pred = model.predict(X)
+            result = model.score(y_true, y_pred)
+        return result
+
+    def __score_func_base(self, model, X, y_true):
+        if type(model) == OmpForestRegressor:
+            y_pred = model.predict_base_estimator(X)
+            result = mean_squared_error(y_true, y_pred)
+        elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
+            y_pred = model.predict_base_estimator(X)
+            result = accuracy_score(y_true, y_pred)
+        else:
+            y_pred = model.predict_base_estimator(X)
+            result = model.score(y_true, y_pred)
+        return result
+
     def compute_results(self, model, models_dir):
         """
         :param model: Object with
         :param models_dir: Where the results will be saved
         """
-        score_func = model.score if type(model) in [RandomForestRegressor, RandomForestClassifier] \
-            else model.score_base_estimator
         results = ModelRawResults(
             model_object=model,
             training_time=self._end_time - self._begin_time,
             datetime=datetime.datetime.now(),
-            train_score=model.score(self._dataset.X_train, self._dataset.y_train),
-            dev_score=model.score(self._dataset.X_dev, self._dataset.y_dev),
-            test_score=model.score(self._dataset.X_test, self._dataset.y_test),
+            train_score=self.__score_func(model, self._dataset.X_train, self._dataset.y_train),
+            dev_score=self.__score_func(model, self._dataset.X_dev, self._dataset.y_dev),
+            test_score=self.__score_func(model, self._dataset.X_test, self._dataset.y_test),
+            train_score_base=self.__score_func_base(model, self._dataset.X_train, self._dataset.y_train),
+            dev_score_base=self.__score_func_base(model, self._dataset.X_dev, self._dataset.y_dev),
+            test_score_base=self.__score_func_base(model, self._dataset.X_test, self._dataset.y_test),
             score_metric='mse' if type(model) in [RandomForestRegressor, RandomForestClassifier] \
                 else model.DEFAULT_SCORE_METRIC, # TODO: resolve the used metric in a proper way
-            train_score_regressor=score_func(self._dataset.X_train, self._dataset.y_train),
-            dev_score_regressor=score_func(self._dataset.X_dev, self._dataset.y_dev),
-            test_score_regressor=score_func(self._dataset.X_test, self._dataset.y_test)
         )
         results.save(models_dir)
-        self._logger.info("Base performance on test: {}".format(results.test_score_regressor))
+        self._logger.info("Base performance on test: {}".format(results.test_score_base))
         self._logger.info("Performance on test: {}".format(results.test_score))
-        self._logger.info("Base performance on train: {}".format(results.train_score_regressor))
+        self._logger.info("Base performance on train: {}".format(results.train_score_base))
         self._logger.info("Performance on train: {}".format(results.train_score))
-        self._logger.info("Base performance on dev: {}".format(results.dev_score_regressor))
+        self._logger.info("Base performance on dev: {}".format(results.dev_score_base))
         self._logger.info("Performance on dev: {}".format(results.dev_score))