From a75125c5cf2fffde28acffd38ac23d5c2e7227a2 Mon Sep 17 00:00:00 2001
From: Charly Lamothe <charly.lamothe@univ-amu.fr>
Date: Fri, 13 Mar 2020 11:31:58 +0100
Subject: [PATCH] Fix logging in case of catching omp warning.

---
 code/bolsonaro/models/omp_forest.py            | 2 +-
 code/bolsonaro/models/omp_forest_classifier.py | 2 +-
 code/compute_results.py                        | 4 +++-
 3 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/code/bolsonaro/models/omp_forest.py b/code/bolsonaro/models/omp_forest.py
index 63ef280..d539f45 100644
--- a/code/bolsonaro/models/omp_forest.py
+++ b/code/bolsonaro/models/omp_forest.py
@@ -122,7 +122,7 @@ class SingleOmpForest(OmpForest):
             caught_warnings = list(filter(lambda i: i.message != RuntimeWarning(omp_premature_warning), caught_warnings))
 
             if len(caught_warnings) > 0:
-                logger.error(f'number of linear dependences in the dictionary: {len(caught_warnings)}. model parameters: {str(self._models_parameters.__dict__)}')
+                self._logger.error(f'number of linear dependences in the dictionary: {len(caught_warnings)}. model parameters: {str(self._models_parameters.__dict__)}')
 
     def predict(self, X):
         """
diff --git a/code/bolsonaro/models/omp_forest_classifier.py b/code/bolsonaro/models/omp_forest_classifier.py
index fe6096d..7a22337 100644
--- a/code/bolsonaro/models/omp_forest_classifier.py
+++ b/code/bolsonaro/models/omp_forest_classifier.py
@@ -104,7 +104,7 @@ class OmpForestMulticlassClassifier(OmpForest):
                 caught_warnings = list(filter(lambda i: i.message != RuntimeWarning(omp_premature_warning), caught_warnings))
 
                 if len(caught_warnings) > 0:
-                    logger.error(f'number of linear dependences in the dictionary: {len(caught_warnings)}. model parameters: {str(self._models_parameters.__dict__)}')
+                    self._logger.error(f'number of linear dependences in the dictionary: {len(caught_warnings)}. model parameters: {str(self._models_parameters.__dict__)}')
 
             self._dct_class_omp[class_label] = omp_class
         return self._dct_class_omp
diff --git a/code/compute_results.py b/code/compute_results.py
index 7d80b4c..d77779e 100644
--- a/code/compute_results.py
+++ b/code/compute_results.py
@@ -157,6 +157,7 @@ if __name__ == "__main__":
     DEFAULT_MODELS_DIR = os.environ["project_dir"] + os.sep + 'models'
     DEFAULT_PLOT_WEIGHT_DENSITY = False
     DEFAULT_WO_LOSS_PLOTS = False
+    DEFAULT_PLOT_PREDS_COHERENCE = False
 
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument('--stage', nargs='?', type=int, required=True, help='Specify the stage number among [1, 5].')
@@ -170,6 +171,7 @@ if __name__ == "__main__":
     parser.add_argument('--models_dir', nargs='?', type=str, default=DEFAULT_MODELS_DIR, help='The output directory of the trained models.')
     parser.add_argument('--plot_weight_density', action='store_true', default=DEFAULT_PLOT_WEIGHT_DENSITY, help='Plot the weight density. Only working for regressor models for now.')
     parser.add_argument('--wo_loss_plots', action='store_true', default=DEFAULT_WO_LOSS_PLOTS, help='Do not compute the loss plots.')
+    parser.add_argument('--plot_preds_coherence', action='store_true', default=DEFAULT_PLOT_PREDS_COHERENCE, help='Plot the coherence of the prediction trees.')
     args = parser.parse_args()
 
     if args.stage not in list(range(1, 6)):
@@ -452,7 +454,7 @@ if __name__ == "__main__":
             omp_with_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes(
                 args.models_dir, args.results_dir, int(args.experiment_ids[2]))
         #omp_with_params_without_weights
-        logger.info('Loading omp_with_params experiment scores...')
+        logger.info('Loading omp_with_params without weights experiment scores...')
         omp_with_params_without_weights_train_scores, omp_with_params_without_weights_dev_scores, omp_with_params_without_weights_test_scores, _, \
             omp_with_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes(
                 args.models_dir, args.results_dir, int(args.experiment_ids[2]), weights=False)
-- 
GitLab