diff --git a/code/bolsonaro/models/omp_forest.py b/code/bolsonaro/models/omp_forest.py index 63ef2803506a411979c3a59d6e3de45b854e85b0..d539f45314a244c410453bb84f726502c6ffe082 100644 --- a/code/bolsonaro/models/omp_forest.py +++ b/code/bolsonaro/models/omp_forest.py @@ -122,7 +122,7 @@ class SingleOmpForest(OmpForest): caught_warnings = list(filter(lambda i: i.message != RuntimeWarning(omp_premature_warning), caught_warnings)) if len(caught_warnings) > 0: - logger.error(f'number of linear dependences in the dictionary: {len(caught_warnings)}. model parameters: {str(self._models_parameters.__dict__)}') + self._logger.error(f'number of linear dependences in the dictionary: {len(caught_warnings)}. model parameters: {str(self._models_parameters.__dict__)}') def predict(self, X): """ diff --git a/code/bolsonaro/models/omp_forest_classifier.py b/code/bolsonaro/models/omp_forest_classifier.py index fe6096db0579eb41da42302bed31278e31ea6e76..7a22337b2fcf48b5181e4971836470e17d0f4f62 100644 --- a/code/bolsonaro/models/omp_forest_classifier.py +++ b/code/bolsonaro/models/omp_forest_classifier.py @@ -104,7 +104,7 @@ class OmpForestMulticlassClassifier(OmpForest): caught_warnings = list(filter(lambda i: i.message != RuntimeWarning(omp_premature_warning), caught_warnings)) if len(caught_warnings) > 0: - logger.error(f'number of linear dependences in the dictionary: {len(caught_warnings)}. model parameters: {str(self._models_parameters.__dict__)}') + self._logger.error(f'number of linear dependences in the dictionary: {len(caught_warnings)}. model parameters: {str(self._models_parameters.__dict__)}') self._dct_class_omp[class_label] = omp_class return self._dct_class_omp diff --git a/code/compute_results.py b/code/compute_results.py index 7d80b4c69308263530566704a8c6747033e99245..d77779e82e295b5e76c0347551c20b8ef258a546 100644 --- a/code/compute_results.py +++ b/code/compute_results.py @@ -157,6 +157,7 @@ if __name__ == "__main__": DEFAULT_MODELS_DIR = os.environ["project_dir"] + os.sep + 'models' DEFAULT_PLOT_WEIGHT_DENSITY = False DEFAULT_WO_LOSS_PLOTS = False + DEFAULT_PLOT_PREDS_COHERENCE = False parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('--stage', nargs='?', type=int, required=True, help='Specify the stage number among [1, 5].') @@ -170,6 +171,7 @@ if __name__ == "__main__": parser.add_argument('--models_dir', nargs='?', type=str, default=DEFAULT_MODELS_DIR, help='The output directory of the trained models.') parser.add_argument('--plot_weight_density', action='store_true', default=DEFAULT_PLOT_WEIGHT_DENSITY, help='Plot the weight density. Only working for regressor models for now.') parser.add_argument('--wo_loss_plots', action='store_true', default=DEFAULT_WO_LOSS_PLOTS, help='Do not compute the loss plots.') + parser.add_argument('--plot_preds_coherence', action='store_true', default=DEFAULT_PLOT_PREDS_COHERENCE, help='Plot the coherence of the prediction trees.') args = parser.parse_args() if args.stage not in list(range(1, 6)): @@ -452,7 +454,7 @@ if __name__ == "__main__": omp_with_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes( args.models_dir, args.results_dir, int(args.experiment_ids[2])) #omp_with_params_without_weights - logger.info('Loading omp_with_params experiment scores...') + logger.info('Loading omp_with_params without weights experiment scores...') omp_with_params_without_weights_train_scores, omp_with_params_without_weights_dev_scores, omp_with_params_without_weights_test_scores, _, \ omp_with_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes( args.models_dir, args.results_dir, int(args.experiment_ids[2]), weights=False)