Skip to content
Snippets Groups Projects
Commit 30a57834 authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Temp modif in stage4 result computation to unravel the difference between different subset usages

parent 4dd4fc23
No related branches found
No related tags found
1 merge request!9Resolve "Experiment pipeline"
......@@ -328,7 +328,7 @@ if __name__ == "__main__":
ylabel=experiments_score_metric,
title='Loss values of {}\nusing different training subsets'.format(args.dataset_name))"""
elif args.stage == 4:
if len(args.experiment_ids) != 3:
if len(args.experiment_ids) != 5:
raise ValueError('In the case of stage 4, the number of specified experiment ids must be 3.')
# Retreive the extracted forest sizes number used in order to have a base forest axis as long as necessary
......@@ -351,6 +351,18 @@ if __name__ == "__main__":
omp_with_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes(
args.models_dir, args.results_dir, args.experiment_ids[2])
# base_with_params
logger.info('Loading base_with_params experiment scores 2...')
_, _, base_with_params_test_scores_2, \
_ = \
extract_scores_across_seeds_and_forest_size(args.models_dir, args.results_dir, args.experiment_ids[3],
extracted_forest_sizes_number)
# random_with_params
logger.info('Loading random_with_params experiment scores 2...')
_, _, random_with_params_test_scores_2, \
_, _ = \
extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, args.experiment_ids[4])
# Sanity check on the metrics retreived
if not (base_with_params_experiment_score_metric == random_with_params_experiment_score_metric
== omp_with_params_experiment_score_metric):
......@@ -362,9 +374,10 @@ if __name__ == "__main__":
Plotter.plot_stage2_losses(
file_path=output_path + os.sep + 'losses.png',
all_experiment_scores=[base_with_params_test_scores, random_with_params_test_scores,
all_experiment_scores=[base_with_params_test_scores, base_with_params_test_scores_2, random_with_params_test_scores,
random_with_params_test_scores_2,
omp_with_params_test_scores],
all_labels=['base', 'random', 'omp'],
all_labels=['base_train-dev', 'base', 'random_train-dev', 'random', 'omp'],
x_value=with_params_extracted_forest_sizes,
xlabel='Number of trees extracted',
ylabel=experiments_score_metric,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment