Skip to content
Snippets Groups Projects
Commit 31d12659 authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Fix few typos in new compute_hyperparams version.

parent 39d4540f
No related branches found
No related tags found
1 merge request!9Resolve "Experiment pipeline"
...@@ -77,7 +77,6 @@ def compute_best_params_over_seeds(seeds, dataset_name, param_space, args): ...@@ -77,7 +77,6 @@ def compute_best_params_over_seeds(seeds, dataset_name, param_space, args):
# Move k best_parameters to a list of dict # Move k best_parameters to a list of dict
all_best_params = [opt_result['_best_parameters'] for opt_result in opt_results] all_best_params = [opt_result['_best_parameters'] for opt_result in opt_results]
print(all_best_params)
""" """
list of hyperparam dicts -> list of hyperparam list list of hyperparam dicts -> list of hyperparam list
...@@ -114,7 +113,7 @@ def compute_best_params_over_seeds(seeds, dataset_name, param_space, args): ...@@ -114,7 +113,7 @@ def compute_best_params_over_seeds(seeds, dataset_name, param_space, args):
break break
return { return {
'_scorer': opt_results[0]['_best_parameters'], '_scorer': opt_results[0]['_scorer'],
'_best_score_train': np.mean([opt_result['_best_score_train'] for opt_result in opt_results]), '_best_score_train': np.mean([opt_result['_best_score_train'] for opt_result in opt_results]),
'_best_score_test': np.mean([opt_result['_best_score_test'] for opt_result in opt_results]), '_best_score_test': np.mean([opt_result['_best_score_test'] for opt_result in opt_results]),
'_best_parameters': best_params, '_best_parameters': best_params,
...@@ -153,7 +152,7 @@ if __name__ == "__main__": ...@@ -153,7 +152,7 @@ if __name__ == "__main__":
logger.warning('seeds and random_seed_number parameters are both specified. Seeds will be used.') logger.warning('seeds and random_seed_number parameters are both specified. Seeds will be used.')
# Seeds are either provided as parameters or generated at random # Seeds are either provided as parameters or generated at random
if args.use_variable_seed_number: if not args.use_variable_seed_number:
seeds = args.seeds if args.seeds is not None \ seeds = args.seeds if args.seeds is not None \
else [random.randint(begin_random_seed_range, end_random_seed_range) \ else [random.randint(begin_random_seed_range, end_random_seed_range) \
for i in range(args.random_seed_number)] for i in range(args.random_seed_number)]
...@@ -174,6 +173,6 @@ if __name__ == "__main__": ...@@ -174,6 +173,6 @@ if __name__ == "__main__":
for i in range(DatasetLoader.dataset_seed_numbers[dataset_name])] for i in range(DatasetLoader.dataset_seed_numbers[dataset_name])]
dict_results = compute_best_params_over_seeds(seeds, dataset_name, dict_results = compute_best_params_over_seeds(seeds, dataset_name,
DICT_PARAM_SPACE, args) DICT_PARAM_SPACE, args)
save_obj_to_json(os.path.join(dataset_dir, 'params.json'), dict_results) save_obj_to_json(os.path.join(dataset_dir, 'params.json'), dict_results)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment