From a6a476ce832a50eed57dd7b3e703d3b97f8e5501 Mon Sep 17 00:00:00 2001 From: Charly LAMOTHE <lamothe.c@intlocal.univ-amu.fr> Date: Tue, 5 Nov 2019 14:30:46 +0100 Subject: [PATCH] Save all attributes of model_raw_results automatically --- code/bolsonaro/models/model_raw_results.py | 44 ++++------------------ code/bolsonaro/trainer.py | 15 +++++++- code/bolsonaro/utils.py | 13 +++++++ 3 files changed, 35 insertions(+), 37 deletions(-) diff --git a/code/bolsonaro/models/model_raw_results.py b/code/bolsonaro/models/model_raw_results.py index 7b849d0..673cb0f 100644 --- a/code/bolsonaro/models/model_raw_results.py +++ b/code/bolsonaro/models/model_raw_results.py @@ -1,4 +1,5 @@ -import pickle +from bolsonaro.utils import save_obj_to_pickle, load_obj_from_pickle + import os import datetime @@ -66,40 +67,11 @@ class ModelRawResults(object): def test_score_regressor(self): return self._test_score_regressor - @staticmethod - def save(models_dir, model, end_time, begin_time, dataset, logger): - output_file_path = models_dir + os.sep + 'model_raw_results.pickle' - logger.debug('Saving trained model and raw results to {}'.format(output_file_path)) - with open(output_file_path, 'wb') as output_file: - pickle.dump({ - 'forest': model.forest, - 'weights': model.weights, - 'training_time': end_time - begin_time, - 'datetime': datetime.datetime.now(), - 'train_score': model.score(dataset.X_train, dataset.y_train), - 'dev_score': model.score(dataset.X_dev, dataset.y_dev), - 'test_score': model.score(dataset.X_test, dataset.y_test), - 'score_metric': model.default_score_metric, - 'train_score_regressor': model.score_regressor(dataset.X_train, dataset.y_train), - 'dev_score_regressor': model.score_regressor(dataset.X_dev, dataset.y_dev), - 'test_score_regressor': model.score_regressor(dataset.X_test, dataset.y_test) - }, output_file) + def save(self, models_dir): + save_obj_to_pickle(models_dir + os.sep + 'model_raw_results.pickle', + self.__dict__) @staticmethod - def load(models_dir): - model_file_path = models_dir + os.sep + 'model_raw_results.pickle' - with open(model_file_path, 'rb') as input_file: - model_data = pickle.load(input_file) - return ModelRawResults( - forest=model_data['forest'], - weights=model_data['weights'], - training_time=model_data['training_time'], - datetime=model_data['datetime'], - train_score=model_data['train_score'], - dev_score=model_data['dev_score'], - test_score=model_data['test_score'], - score_metric=model_data['score_metric'], - train_score_regressor=model_data['train_score_regressor'], - dev_score_regressor=model_data['dev_score_regressor'], - test_score_regressor=model_data['test_score_regressor'] - ) + def load(models_dir): + return load_obj_from_pickle(models_dir + os.sep + 'model_raw_results.pickle', + ModelRawResults) diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py index 1120961..08d745c 100644 --- a/code/bolsonaro/trainer.py +++ b/code/bolsonaro/trainer.py @@ -3,6 +3,7 @@ from bolsonaro.error_handling.logger_factory import LoggerFactory from . import LOG_PATH import time +import datetime class Trainer(object): @@ -25,4 +26,16 @@ class Trainer(object): model.fit(X, y) end_time = time.time() - ModelRawResults.save(models_dir, model, end_time, begin_time, self._dataset, self._logger) + ModelRawResults( + forest=model.forest, + weights=model.weights, + training_time=end_time - begin_time, + datetime=datetime.datetime.now(), + train_score=model.score(self._dataset.X_train, self._dataset.y_train), + dev_score=model.score(self._dataset.X_dev, self._dataset.y_dev), + test_score=model.score(self._dataset.X_test, self._dataset.y_test), + score_metric=model.default_score_metric, + train_score_regressor=model.score_regressor(self._dataset.X_train, self._dataset.y_train), + dev_score_regressor=model.score_regressor(self._dataset.X_dev, self._dataset.y_dev), + test_score_regressor=model.score_regressor(self._dataset.X_test, self._dataset.y_test) + ).save(models_dir) diff --git a/code/bolsonaro/utils.py b/code/bolsonaro/utils.py index 4186eef..a4d86e0 100644 --- a/code/bolsonaro/utils.py +++ b/code/bolsonaro/utils.py @@ -1,5 +1,6 @@ import os import json +import pickle def resolve_experiment_id(models_dir): @@ -33,3 +34,15 @@ def load_obj_from_json(file_path, constructor): with open(file_path, 'r') as input_file: parameters = json.load(input_file) return constructor(**parameters) + +def save_obj_to_pickle(file_path, attributes_dict): + attributes = dict() + for key, value in attributes_dict.items(): + attributes[key[1:]] = value + with open(file_path, 'wb') as output_file: + pickle.dump(attributes, output_file) + +def load_obj_from_pickle(file_path, constructor): + with open(file_path, 'rb') as input_file: + parameters = pickle.load(input_file) + return constructor(**parameters) -- GitLab