From 881106aee61f6e431ae764d01f064bb4418427fa Mon Sep 17 00:00:00 2001
From: Charly Lamothe <charly.lamothe@univ-amu.fr>
Date: Fri, 28 Feb 2020 18:12:28 +0100
Subject: [PATCH] Add weights saving. TODO: density plots at least

---
 code/bolsonaro/trainer.py |  8 +++++---
 code/compute_results.py   | 19 +++++++++++++------
 2 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py
index 5920fa2..389ab9d 100644
--- a/code/bolsonaro/trainer.py
+++ b/code/bolsonaro/trainer.py
@@ -128,10 +128,12 @@ class Trainer(object):
         """
 
         model_weights = ''
-        if type(model) == RandomForestRegressor:
-            model_weights = model.coef_
-        elif type(model) == OmpForestRegressor:
+        if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier]:
             model_weights = model._omp.coef_
+        elif type(model) == OmpForestMulticlassClassifier:
+            model_weights = model._dct_class_omp
+        elif type(model) == OmpForestBinaryClassifier:
+            model_weights = model._omp
 
         results = ModelRawResults(
             model_weights=model_weights,
diff --git a/code/compute_results.py b/code/compute_results.py
index bad281c..4fce327 100644
--- a/code/compute_results.py
+++ b/code/compute_results.py
@@ -28,13 +28,12 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
     experiment_train_scores = dict()
     experiment_dev_scores = dict()
     experiment_test_scores = dict()
+    experiment_weights = dict()
     all_extracted_forest_sizes = list()
 
     # Used to check if all losses were computed using the same metric (it should be the case)
     experiment_score_metrics = list()
 
-    all_weights = list()
-
     # For each seed results stored in models/{experiment_id}/seeds
     seeds = os.listdir(experiment_seed_root_path)
     seeds.sort(key=int)
@@ -46,6 +45,7 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
         experiment_train_scores[seed] = list()
         experiment_dev_scores[seed] = list()
         experiment_test_scores[seed] = list()
+        experiment_weights[seed] = list()
 
         # List the forest sizes in models/{experiment_id}/seeds/{seed}/extracted_forest_sizes
         extracted_forest_sizes = os.listdir(extracted_forest_sizes_root_path)
@@ -62,6 +62,8 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
             experiment_test_scores[seed].append(model_raw_results.test_score)
             # Save the metric
             experiment_score_metrics.append(model_raw_results.score_metric)
+            # Save the weights
+            #experiment_weights[seed].append(model_raw_results.model_weights)
 
     # Sanity checks
     if len(set(experiment_score_metrics)) > 1:
@@ -69,7 +71,8 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
     if len(set([sum(extracted_forest_sizes) for extracted_forest_sizes in all_extracted_forest_sizes])) != 1:
         raise ValueError("The extracted forest sizes aren't the sames across seeds.")
 
-    return experiment_train_scores, experiment_dev_scores, experiment_test_scores, all_extracted_forest_sizes[0], experiment_score_metrics[0]
+    return experiment_train_scores, experiment_dev_scores, experiment_test_scores, \
+        all_extracted_forest_sizes[0], experiment_score_metrics[0]#, experiment_weights
 
 def extract_scores_across_seeds_and_forest_size(models_dir, results_dir, experiment_id, extracted_forest_sizes_number):
     experiment_id_path = models_dir + os.sep + str(experiment_id) # models/{experiment_id}
@@ -228,8 +231,6 @@ if __name__ == "__main__":
             ylabel=experiments_score_metric,
             title='Loss values of {}\nusing best and default hyperparameters'.format(args.dataset_name)
         )
-
-        Plotter.plot_weight_density()
     elif args.stage == 2:
         if len(args.experiment_ids) != 4:
             raise ValueError('In the case of stage 2, the number of specified experiment ids must be 4.')
@@ -353,6 +354,9 @@ if __name__ == "__main__":
             extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, args.experiment_ids[1])
         # omp_with_params
         logger.info('Loading omp_with_params experiment scores...')
+        """omp_with_params_train_scores, omp_with_params_dev_scores, omp_with_params_test_scores, _, \
+            omp_with_params_experiment_score_metric, experiment_weights = extract_scores_across_seeds_and_extracted_forest_sizes(
+                args.models_dir, args.results_dir, args.experiment_ids[2])"""
         omp_with_params_train_scores, omp_with_params_dev_scores, omp_with_params_test_scores, _, \
             omp_with_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes(
                 args.models_dir, args.results_dir, args.experiment_ids[2])
@@ -375,7 +379,7 @@ if __name__ == "__main__":
             raise ValueError('Score metrics of all experiments must be the same.')
         experiments_score_metric = base_with_params_experiment_score_metric
 
-        output_path = os.path.join(args.results_dir, args.dataset_name, 'stage4')
+        output_path = os.path.join(args.results_dir, args.dataset_name, 'stage4_fix')
         pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)
 
         Plotter.plot_stage2_losses(
@@ -386,6 +390,9 @@ if __name__ == "__main__":
             xlabel='Number of trees extracted',
             ylabel=experiments_score_metric,
             title='Loss values of {}\nusing best params of previous stages'.format(args.dataset_name))
+
+        # experiment_weights
+        #Plotter.weight_density(experiment_weights, output_path + os.sep + 'weight_density.png')
     else:
         raise ValueError('This stage number is not supported yet, but it will be!')
 
-- 
GitLab