Skip to content
Snippets Groups Projects
Commit a6a476ce authored by Charly LAMOTHE's avatar Charly LAMOTHE
Browse files

Save all attributes of model_raw_results automatically

parent ec569ad5
No related branches found
No related tags found
1 merge request!3clean scripts
import pickle from bolsonaro.utils import save_obj_to_pickle, load_obj_from_pickle
import os import os
import datetime import datetime
...@@ -66,40 +67,11 @@ class ModelRawResults(object): ...@@ -66,40 +67,11 @@ class ModelRawResults(object):
def test_score_regressor(self): def test_score_regressor(self):
return self._test_score_regressor return self._test_score_regressor
@staticmethod def save(self, models_dir):
def save(models_dir, model, end_time, begin_time, dataset, logger): save_obj_to_pickle(models_dir + os.sep + 'model_raw_results.pickle',
output_file_path = models_dir + os.sep + 'model_raw_results.pickle' self.__dict__)
logger.debug('Saving trained model and raw results to {}'.format(output_file_path))
with open(output_file_path, 'wb') as output_file:
pickle.dump({
'forest': model.forest,
'weights': model.weights,
'training_time': end_time - begin_time,
'datetime': datetime.datetime.now(),
'train_score': model.score(dataset.X_train, dataset.y_train),
'dev_score': model.score(dataset.X_dev, dataset.y_dev),
'test_score': model.score(dataset.X_test, dataset.y_test),
'score_metric': model.default_score_metric,
'train_score_regressor': model.score_regressor(dataset.X_train, dataset.y_train),
'dev_score_regressor': model.score_regressor(dataset.X_dev, dataset.y_dev),
'test_score_regressor': model.score_regressor(dataset.X_test, dataset.y_test)
}, output_file)
@staticmethod @staticmethod
def load(models_dir): def load(models_dir):
model_file_path = models_dir + os.sep + 'model_raw_results.pickle' return load_obj_from_pickle(models_dir + os.sep + 'model_raw_results.pickle',
with open(model_file_path, 'rb') as input_file: ModelRawResults)
model_data = pickle.load(input_file)
return ModelRawResults(
forest=model_data['forest'],
weights=model_data['weights'],
training_time=model_data['training_time'],
datetime=model_data['datetime'],
train_score=model_data['train_score'],
dev_score=model_data['dev_score'],
test_score=model_data['test_score'],
score_metric=model_data['score_metric'],
train_score_regressor=model_data['train_score_regressor'],
dev_score_regressor=model_data['dev_score_regressor'],
test_score_regressor=model_data['test_score_regressor']
)
...@@ -3,6 +3,7 @@ from bolsonaro.error_handling.logger_factory import LoggerFactory ...@@ -3,6 +3,7 @@ from bolsonaro.error_handling.logger_factory import LoggerFactory
from . import LOG_PATH from . import LOG_PATH
import time import time
import datetime
class Trainer(object): class Trainer(object):
...@@ -25,4 +26,16 @@ class Trainer(object): ...@@ -25,4 +26,16 @@ class Trainer(object):
model.fit(X, y) model.fit(X, y)
end_time = time.time() end_time = time.time()
ModelRawResults.save(models_dir, model, end_time, begin_time, self._dataset, self._logger) ModelRawResults(
forest=model.forest,
weights=model.weights,
training_time=end_time - begin_time,
datetime=datetime.datetime.now(),
train_score=model.score(self._dataset.X_train, self._dataset.y_train),
dev_score=model.score(self._dataset.X_dev, self._dataset.y_dev),
test_score=model.score(self._dataset.X_test, self._dataset.y_test),
score_metric=model.default_score_metric,
train_score_regressor=model.score_regressor(self._dataset.X_train, self._dataset.y_train),
dev_score_regressor=model.score_regressor(self._dataset.X_dev, self._dataset.y_dev),
test_score_regressor=model.score_regressor(self._dataset.X_test, self._dataset.y_test)
).save(models_dir)
import os import os
import json import json
import pickle
def resolve_experiment_id(models_dir): def resolve_experiment_id(models_dir):
...@@ -33,3 +34,15 @@ def load_obj_from_json(file_path, constructor): ...@@ -33,3 +34,15 @@ def load_obj_from_json(file_path, constructor):
with open(file_path, 'r') as input_file: with open(file_path, 'r') as input_file:
parameters = json.load(input_file) parameters = json.load(input_file)
return constructor(**parameters) return constructor(**parameters)
def save_obj_to_pickle(file_path, attributes_dict):
attributes = dict()
for key, value in attributes_dict.items():
attributes[key[1:]] = value
with open(file_path, 'wb') as output_file:
pickle.dump(attributes, output_file)
def load_obj_from_pickle(file_path, constructor):
with open(file_path, 'rb') as input_file:
parameters = pickle.load(input_file)
return constructor(**parameters)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment