From 31d12659ce085f9433b79cc83dceae8d125eb841 Mon Sep 17 00:00:00 2001
From: Charly Lamothe <charly.lamothe@univ-amu.fr>
Date: Sun, 1 Dec 2019 21:36:43 +0100
Subject: [PATCH] Fix few typos in new compute_hyperparams version.

---
 code/compute_hyperparameters.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/code/compute_hyperparameters.py b/code/compute_hyperparameters.py
index e6068e0..a96ef68 100644
--- a/code/compute_hyperparameters.py
+++ b/code/compute_hyperparameters.py
@@ -77,7 +77,6 @@ def compute_best_params_over_seeds(seeds, dataset_name, param_space, args):
 
     # Move k best_parameters to a list of dict
     all_best_params = [opt_result['_best_parameters'] for opt_result in opt_results]
-    print(all_best_params)
 
     """
     list of hyperparam dicts -> list of hyperparam list
@@ -114,7 +113,7 @@ def compute_best_params_over_seeds(seeds, dataset_name, param_space, args):
             break
 
     return {
-        '_scorer': opt_results[0]['_best_parameters'],
+        '_scorer': opt_results[0]['_scorer'],
         '_best_score_train': np.mean([opt_result['_best_score_train'] for opt_result in opt_results]),
         '_best_score_test': np.mean([opt_result['_best_score_test'] for opt_result in opt_results]),
         '_best_parameters': best_params,
@@ -153,7 +152,7 @@ if __name__ == "__main__":
         logger.warning('seeds and random_seed_number parameters are both specified. Seeds will be used.')    
 
     # Seeds are either provided as parameters or generated at random
-    if args.use_variable_seed_number:
+    if not args.use_variable_seed_number:
         seeds = args.seeds if args.seeds is not None \
             else [random.randint(begin_random_seed_range, end_random_seed_range) \
             for i in range(args.random_seed_number)]
@@ -174,6 +173,6 @@ if __name__ == "__main__":
                 for i in range(DatasetLoader.dataset_seed_numbers[dataset_name])]
 
         dict_results = compute_best_params_over_seeds(seeds, dataset_name,
-            DICT_PARAM_SPACE, args)        
+            DICT_PARAM_SPACE, args)
 
         save_obj_to_json(os.path.join(dataset_dir, 'params.json'), dict_results)
-- 
GitLab