From a6a476ce832a50eed57dd7b3e703d3b97f8e5501 Mon Sep 17 00:00:00 2001
From: Charly LAMOTHE <lamothe.c@intlocal.univ-amu.fr>
Date: Tue, 5 Nov 2019 14:30:46 +0100
Subject: [PATCH] Save all attributes of model_raw_results automatically

---
 code/bolsonaro/models/model_raw_results.py | 44 ++++------------------
 code/bolsonaro/trainer.py                  | 15 +++++++-
 code/bolsonaro/utils.py                    | 13 +++++++
 3 files changed, 35 insertions(+), 37 deletions(-)

diff --git a/code/bolsonaro/models/model_raw_results.py b/code/bolsonaro/models/model_raw_results.py
index 7b849d0..673cb0f 100644
--- a/code/bolsonaro/models/model_raw_results.py
+++ b/code/bolsonaro/models/model_raw_results.py
@@ -1,4 +1,5 @@
-import pickle
+from bolsonaro.utils import save_obj_to_pickle, load_obj_from_pickle
+
 import os
 import datetime
 
@@ -66,40 +67,11 @@ class ModelRawResults(object):
     def test_score_regressor(self):
         return self._test_score_regressor
 
-    @staticmethod
-    def save(models_dir, model, end_time, begin_time, dataset, logger):
-        output_file_path = models_dir + os.sep + 'model_raw_results.pickle'
-        logger.debug('Saving trained model and raw results to {}'.format(output_file_path))
-        with open(output_file_path, 'wb') as output_file:
-            pickle.dump({
-                'forest': model.forest,
-                'weights': model.weights,
-                'training_time': end_time - begin_time,
-                'datetime': datetime.datetime.now(),
-                'train_score': model.score(dataset.X_train, dataset.y_train),
-                'dev_score': model.score(dataset.X_dev, dataset.y_dev),
-                'test_score': model.score(dataset.X_test, dataset.y_test),
-                'score_metric': model.default_score_metric,
-                'train_score_regressor': model.score_regressor(dataset.X_train, dataset.y_train),
-                'dev_score_regressor': model.score_regressor(dataset.X_dev, dataset.y_dev),
-                'test_score_regressor': model.score_regressor(dataset.X_test, dataset.y_test)
-            }, output_file)
+    def save(self, models_dir):
+        """Pickle all instance attributes to model_raw_results.pickle in models_dir."""
+        save_obj_to_pickle(os.sep.join((models_dir, 'model_raw_results.pickle')), vars(self))
 
     @staticmethod
-    def load(models_dir):
-        model_file_path = models_dir + os.sep + 'model_raw_results.pickle'
-        with open(model_file_path, 'rb') as input_file:
-            model_data = pickle.load(input_file)
-        return ModelRawResults(
-            forest=model_data['forest'],
-            weights=model_data['weights'],
-            training_time=model_data['training_time'],
-            datetime=model_data['datetime'],
-            train_score=model_data['train_score'],
-            dev_score=model_data['dev_score'],
-            test_score=model_data['test_score'],
-            score_metric=model_data['score_metric'],
-            train_score_regressor=model_data['train_score_regressor'],
-            dev_score_regressor=model_data['dev_score_regressor'],
-            test_score_regressor=model_data['test_score_regressor']
-        )
+    def load(models_dir):
+        """Rebuild a ModelRawResults from model_raw_results.pickle in models_dir."""
+        return load_obj_from_pickle(os.sep.join((models_dir, 'model_raw_results.pickle')), ModelRawResults)
diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py
index 1120961..08d745c 100644
--- a/code/bolsonaro/trainer.py
+++ b/code/bolsonaro/trainer.py
@@ -3,6 +3,7 @@ from bolsonaro.error_handling.logger_factory import LoggerFactory
 from . import LOG_PATH
 
 import time
+import datetime
 
 
 class Trainer(object):
@@ -25,4 +26,16 @@ class Trainer(object):
         model.fit(X, y)
         end_time = time.time()
 
-        ModelRawResults.save(models_dir, model, end_time, begin_time, self._dataset, self._logger)
+        ModelRawResults(
+            forest=model.forest,
+            weights=model.weights,
+            training_time=end_time - begin_time,
+            datetime=datetime.datetime.now(),
+            train_score=model.score(self._dataset.X_train, self._dataset.y_train),
+            dev_score=model.score(self._dataset.X_dev, self._dataset.y_dev),
+            test_score=model.score(self._dataset.X_test, self._dataset.y_test),
+            score_metric=model.default_score_metric,
+            train_score_regressor=model.score_regressor(self._dataset.X_train, self._dataset.y_train),
+            dev_score_regressor=model.score_regressor(self._dataset.X_dev, self._dataset.y_dev),
+            test_score_regressor=model.score_regressor(self._dataset.X_test, self._dataset.y_test)
+        ).save(models_dir)
diff --git a/code/bolsonaro/utils.py b/code/bolsonaro/utils.py
index 4186eef..a4d86e0 100644
--- a/code/bolsonaro/utils.py
+++ b/code/bolsonaro/utils.py
@@ -1,5 +1,6 @@
 import os
 import json
+import pickle
 
 
 def resolve_experiment_id(models_dir):
@@ -33,3 +34,15 @@ def load_obj_from_json(file_path, constructor):
     with open(file_path, 'r') as input_file:
         parameters = json.load(input_file)
     return constructor(**parameters)
+
+def save_obj_to_pickle(file_path, attributes_dict):
+    """Pickle attributes_dict to file_path, dropping a leading '_' from keys."""
+    attributes = {(key[1:] if key.startswith('_') else key): value
+                  for key, value in attributes_dict.items()}
+    with open(file_path, 'wb') as output_file:
+        pickle.dump(attributes, output_file)
+
+def load_obj_from_pickle(file_path, constructor):
+    """Unpickle a keyword dict from file_path and return constructor(**dict)."""
+    with open(file_path, 'rb') as pickle_file:
+        return constructor(**pickle.load(pickle_file))
-- 
GitLab