def save(self, models_dir):
    """Persist this object's instance attributes as a pickle file.

    The attributes are taken from ``self.__dict__`` and handed to
    ``save_obj_to_pickle``; the file is always named
    ``model_raw_results.pickle`` and placed under ``models_dir``.

    :param models_dir: directory path (string) receiving the pickle file.
    """
    # Joining two parts with os.sep yields models_dir + os.sep + file name.
    output_file_path = os.sep.join((models_dir, 'model_raw_results.pickle'))
    save_obj_to_pickle(output_file_path, self.__dict__)

@staticmethod
def load(models_dir):
    """Rebuild a ``ModelRawResults`` from the pickle stored in ``models_dir``.

    Reads ``model_raw_results.pickle`` under ``models_dir`` through
    ``load_obj_from_pickle``, passing ``ModelRawResults`` as the constructor.

    :param models_dir: directory path (string) containing the pickle file.
    :return: whatever ``load_obj_from_pickle`` builds with ``ModelRawResults``.
    """
    input_file_path = os.sep.join((models_dir, 'model_raw_results.pickle'))
    return load_obj_from_pickle(input_file_path, ModelRawResults)
def save_obj_to_pickle(file_path, attributes_dict):
    """Pickle an attribute mapping so it can be replayed as keyword arguments.

    A key that starts with an underscore has that single leading underscore
    removed (``_forest`` is stored as ``forest``). Keys without a leading
    underscore are stored unchanged instead of losing their first character.

    :param file_path: path of the pickle file to write (opened in ``'wb'``).
    :param attributes_dict: mapping of attribute names (strings) to values,
        typically an object's ``__dict__``.
    """
    attributes = {
        (key[1:] if key.startswith('_') else key): value
        for key, value in attributes_dict.items()
    }
    with open(file_path, 'wb') as output_file:
        pickle.dump(attributes, output_file)


def load_obj_from_pickle(file_path, constructor):
    """Build an object from a pickled mapping of keyword arguments.

    :param file_path: path of the pickle file to read (opened in ``'rb'``).
    :param constructor: callable invoked as ``constructor(**parameters)``
        with the unpickled mapping.
    :return: the value returned by ``constructor``.
    """
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only call this on files produced by trusted code.
    with open(file_path, 'rb') as input_file:
        parameters = pickle.load(input_file)
    return constructor(**parameters)