diff --git a/config_files/config.yml b/config_files/config.yml
index eeaa3a6b403d0d1d697781caaf1fa4fe962386e2..b0f051bd79e70cc643d99b5eac8f0ab4293be7d9 100644
--- a/config_files/config.yml
+++ b/config_files/config.yml
@@ -1,62 +1,64 @@
 # The base configuration of the benchmark
-Base :
-  # Enable logging
-  log: True
-  # The name of each dataset in the directory on which the benchmark should be run
-  name: ["plausible"]
-  # A label for the resul directory
-  label: "_"
-  # The type of dataset, currently supported ".hdf5", and ".csv"
-  type: ".hdf5"
-  # The views to use in the banchmark, an empty value will result in using all the views
-  views:
-  # The path to the directory where the datasets are stored
-  pathf: "../data/"
-  # The niceness of the processes, useful to lower their priority
-  nice: 0
-  # The random state of the benchmark, useful for reproducibility
-  random_state: 42
-  # The number of parallel computing threads
-  nb_cores: 1
-  # Used to run the benchmark on the full dataset
-  full: False
-  # Used to be able to run more than one benchmark per minute
-  debug: False
-  # To add noise to the data, will add gaussian noise with noise_std
-  add_noise: False
-  noise_std: 0.0
-  # The directory in which the results will be stored
-  res_dir: "../results/"
-
-# All the classification-realted configuration options
-Classification:
-  # If the dataset is multiclass, will use this multiclass-to-biclass method
-  multiclass_method: "oneVersusOne"
-  # The ratio number of test exmaples/number of train examples
-  split: 0.8
-  # The nubmer of folds in the cross validation process when hyper-paramter optimization is performed
-  nb_folds: 2
-  # The number of classes to select in the dataset
-  nb_class: 2
-  # The name of the classes to select in the dataset
-  classes:
-  # The type of algorithms to run during the benchmark (monoview and/or multiview)
-  type: ["monoview","multiview"]
-  # The name of the monoview algorithms to run, ["all"] to run all the available classifiers
-  algos_monoview: ["all"]
-  # The names of the multiview algorithms to run, ["all"] to run all the available classifiers
-  algos_multiview: ["all"]
-  # The number of times the benchamrk is repeated with different train/test
-  # split, to have more statistically significant results
-  stats_iter: 1
-  # The metrics that will be use din the result analysis
-  metrics: ["accuracy_score", "f1_score"]
-  # The metric that will be used in the hyper-parameter optimization process
-  metric_princ: "f1_score"
-  # The type of hyper-parameter optimization method
-  hps_type: "randomized_search"
-  # The number of iteration in the hyper-parameter optimization process
-  hps_iter: 2
+
+# Enable logging
+log: True
+# The name of each dataset in the directory on which the benchmark should be run
+name: ["plausible"]
+# A label for the result directory
+label: "_"
+# The type of dataset, currently supported ".hdf5", and ".csv"
+file_type: ".hdf5"
+# The views to use in the benchmark, an empty value will result in using all the views
+views:
+# The path to the directory where the datasets are stored
+pathf: "../data/"
+# The niceness of the processes, useful to lower their priority
+nice: 0
+# The random state of the benchmark, useful for reproducibility
+random_state: 42
+# The number of parallel computing threads
+nb_cores: 1
+# Used to run the benchmark on the full dataset
+full: False
+# Used to be able to run more than one benchmark per minute
+debug: False
+# To add noise to the data, will add gaussian noise with noise_std
+add_noise: False
+noise_std: 0.0
+# The directory in which the results will be stored
+res_dir: "../results/"
+# If an error occurs in a classifier and track_tracebacks is set to True, the
+# benchmark saves the traceback and continues; if it is set to False, it will
+# stop the benchmark and raise the error
+track_tracebacks: True
+
+# If the dataset is multiclass, will use this multiclass-to-biclass method
+multiclass_method: "oneVersusOne"
+# The ratio number of test examples/number of train examples
+split: 0.8
+# The number of folds in the cross-validation process when hyper-parameter optimization is performed
+nb_folds: 2
+# The number of classes to select in the dataset
+nb_class: 2
+# The name of the classes to select in the dataset
+classes:
+# The type of algorithms to run during the benchmark (monoview and/or multiview)
+type: ["monoview","multiview"]
+# The name of the monoview algorithms to run, ["all"] to run all the available classifiers
+algos_monoview: ["all"]
+# The names of the multiview algorithms to run, ["all"] to run all the available classifiers
+algos_multiview: ["all"]
+# The number of times the benchmark is repeated with different train/test
+# split, to have more statistically significant results
+stats_iter: 1
+# The metrics that will be used in the result analysis
+metrics: ["accuracy_score", "f1_score"]
+# The metric that will be used in the hyper-parameter optimization process
+metric_princ: "f1_score"
+# The type of hyper-parameter optimization method
+hps_type: "randomized_search"
+# The number of iterations in the hyper-parameter optimization process
+hps_iter: 2
 
 
 # The following arguments are classifier-specific, and are documented in each
diff --git a/config_files/config_test.yml b/config_files/config_test.yml
index 6f5c1407ee3977856c4c0ff97022877651d00cdc..9118a96bf0f260267f131e0936b839d8a4ac78cc 100644
--- a/config_files/config_test.yml
+++ b/config_files/config_test.yml
@@ -1,36 +1,34 @@
 # The base configuration of the benchmark
-Base :
-  log: True
-  name: ["plausible",]
-  label: "_"
-  type: ".hdf5"
-  views:
-  pathf: "../data/"
-  nice: 0
-  random_state: 42
-  nb_cores: 1
-  full: False
-  debug: True
-  add_noise: False
-  noise_std: 0.0
-  res_dir: "../results/"
-  track_tracebacks: False
+log: True
+name: ["plausible",]
+label: "_"
+file_type: ".hdf5"
+views:
+pathf: "../data/"
+nice: 0
+random_state: 42
+nb_cores: 1
+full: False
+debug: True
+add_noise: False
+noise_std: 0.0
+res_dir: "../results/"
+track_tracebacks: False
 
 # All the classification-realted configuration options
-Classification:
-  multiclass_method: "oneVersusOne"
-  split: 0.49
-  nb_folds: 2
-  nb_class: 3
-  classes:
-  type: ["multiview",]
-  algos_monoview: ["all" ]
-  algos_multiview: ["svm_jumbo_fusion",]
-  stats_iter: 2
-  metrics: ["accuracy_score", "f1_score"]
-  metric_princ: "f1_score"
-  hps_type: "randomized_search"
-  hps_iter: 1
+multiclass_method: "oneVersusOne"
+split: 0.49
+nb_folds: 2
+nb_class: 3
+classes:
+type: ["multiview",]
+algos_monoview: ["all" ]
+algos_multiview: ["svm_jumbo_fusion",]
+stats_iter: 2
+metrics: ["accuracy_score", "f1_score"]
+metric_princ: "f1_score"
+hps_type: "randomized_search"
+hps_iter: 1
 
 
 ######################################
diff --git a/multiview_platform/mono_multi_view_classifiers/exec_classif.py b/multiview_platform/mono_multi_view_classifiers/exec_classif.py
index 2c19d36709043e2f8b45bb0a86c1b8a083afd3ed..4ed3f665313a63a051756730564bbc334e2cccdc 100644
--- a/multiview_platform/mono_multi_view_classifiers/exec_classif.py
+++ b/multiview_platform/mono_multi_view_classifiers/exec_classif.py
@@ -552,24 +552,24 @@ def benchmark_init(directory, classification_indices, labels, labels_dictionary,
 #
 # logging.debug("Start:\t monoview benchmark")
 # results_monoview += [
-#     exec_monoview_multicore(directory, args["Base"]["name"], labels_names,
+#     exec_monoview_multicore(directory, args["name"], labels_names,
 #                             classification_indices, k_folds,
-#                             core_index, args["Base"]["type"], args["Base"]["pathf"], random_state,
+#                             core_index, args["file_type"], args["pathf"], random_state,
 #                             labels,
 #                             hyper_param_search=hyper_param_search,
 #                             metrics=metrics,
-#                             n_iter=args["Classification"]["hps_iter"], **argument)
+#                             n_iter=args["hps_iter"], **argument)
 #     for argument in argument_dictionaries["Monoview"]]
 # logging.debug("Done:\t monoview benchmark")
 #
 #
 # logging.debug("Start:\t multiview benchmark")
 # results_multiview = [
-#     exec_multiview_multicore(directory, core_index, args["Base"]["name"],
-#                              classification_indices, k_folds, args["Base"]["type"],
-#                              args["Base"]["pathf"], labels_dictionary, random_state,
+#     exec_multiview_multicore(directory, core_index, args["name"],
+#                              classification_indices, k_folds, args["file_type"],
+#                              args["pathf"], labels_dictionary, random_state,
 #                              labels, hyper_param_search=hyper_param_search,
-#                              metrics=metrics, n_iter=args["Classification"]["hps_iter"],
+#                              metrics=metrics, n_iter=args["hps_iter"],
 #                              **arguments)
 #     for arguments in argument_dictionaries["multiview"]]
 # logging.debug("Done:\t multiview benchmark")
@@ -599,13 +599,13 @@ def benchmark_init(directory, classification_indices, labels, labels_dictionary,
 # nb_multicore_to_do = int(math.ceil(float(nb_experiments) / nb_cores))
 # for step_index in range(nb_multicore_to_do):
 #     results_monoview += (Parallel(n_jobs=nb_cores)(
-#         delayed(exec_monoview_multicore)(directory, args["Base"]["name"], labels_names,
+#         delayed(exec_monoview_multicore)(directory, args["name"], labels_names,
 #                                          classification_indices, k_folds,
-#                                          core_index, args["Base"]["type"], args["Base"]["pathf"],
+#                                          core_index, args["file_type"], args["pathf"],
 #                                          random_state, labels,
 #                                          hyper_param_search=hyper_param_search,
 #                                          metrics=metrics,
-#                                          n_iter=args["Classification"]["hps_iter"],
+#                                          n_iter=args["hps_iter"],
 #                                          **argument_dictionaries["monoview"][
 #                                              core_index + step_index * nb_cores])
 #         for core_index in
@@ -627,14 +627,14 @@ def benchmark_init(directory, classification_indices, labels, labels_dictionary,
 # nb_multicore_to_do = int(math.ceil(float(nb_experiments) / nb_cores))
 # for step_index in range(nb_multicore_to_do):
 #     results_multiview += Parallel(n_jobs=nb_cores)(
-#         delayed(exec_multiview_multicore)(directory, core_index, args["Base"]["name"],
+#         delayed(exec_multiview_multicore)(directory, core_index, args["name"],
 #                                           classification_indices, k_folds,
-#                                           args["Base"]["type"], args["Base"]["pathf"],
+#                                           args["file_type"], args["Base"]["pathf"],
 #                                           labels_dictionary, random_state,
 #                                           labels,
 #                                           hyper_param_search=hyper_param_search,
 #                                           metrics=metrics,
-#                                           n_iter=args["Classification"]["hps_iter"],
+#                                           n_iter=args["hps_iter"],
 #                                           **
 #                                           argument_dictionaries["multiview"][
 #                                               step_index * nb_cores + core_index])
@@ -664,11 +664,11 @@ def exec_one_benchmark_mono_core(dataset_var=None, labels_dictionary=None,
             X = dataset_var.get_v(arguments["view_index"])
             Y = dataset_var.get_labels()
             results_monoview += [
-                exec_monoview(directory, X, Y, args["Base"]["name"], labels_names,
+                exec_monoview(directory, X, Y, args["name"], labels_names,
                               classification_indices, k_folds,
-                              1, args["Base"]["type"], args["Base"]["pathf"], random_state,
args["Base"]["type"], args["Base"]["pathf"], random_state, + 1, args["file_type"], args["pathf"], random_state, hyper_param_search=hyper_param_search, metrics=metrics, - n_iter=args["Classification"]["hps_iter"], **arguments)] + n_iter=args["hps_iter"], **arguments)] except: if track_tracebacks: traceback_outputs[arguments["classifier_name"]+"-"+arguments["view_name"]] = traceback.format_exc() @@ -692,11 +692,11 @@ def exec_one_benchmark_mono_core(dataset_var=None, labels_dictionary=None, for arguments in argument_dictionaries["multiview"]: try: results_multiview += [ - exec_multiview(directory, dataset_var, args["Base"]["name"], classification_indices, - k_folds, 1, args["Base"]["type"], - args["Base"]["pathf"], labels_dictionary, random_state, labels, + exec_multiview(directory, dataset_var, args["name"], classification_indices, + k_folds, 1, args["file_type"], + args["pathf"], labels_dictionary, random_state, labels, hyper_param_search=hyper_param_search, - metrics=metrics, n_iter=args["Classification"]["hps_iter"], **arguments)] + metrics=metrics, n_iter=args["hps_iter"], **arguments)] except: if track_tracebacks: traceback_outputs[arguments["classifier_name"]] = traceback.format_exc() @@ -810,63 +810,64 @@ def exec_classif(arguments): start = time.time() args = execution.parse_the_args(arguments) args = configuration.get_the_args(args.config_path) - os.nice(args["Base"]["nice"]) - nb_cores = args["Base"]["nb_cores"] + os.nice(args["nice"]) + nb_cores = args["nb_cores"] if nb_cores == 1: os.environ['OPENBLAS_NUM_THREADS'] = '1' - stats_iter = args["Classification"]["stats_iter"] - hyper_param_search = args["Classification"]["hps_type"] - multiclass_method = args["Classification"]["multiclass_method"] - cl_type = args["Classification"]["type"] - monoview_algos = args["Classification"]["algos_monoview"] - multiview_algos = args["Classification"]["algos_multiview"] - dataset_list = execution.find_dataset_names(args["Base"]["pathf"], - args["Base"]["type"], - args["Base"]["name"]) - if not args["Base"]["add_noise"]: - args["Base"]["noise_std"]=[0.0] - + stats_iter = args["stats_iter"] + hyper_param_search = args["hps_type"] + multiclass_method = args["multiclass_method"] + cl_type = args["type"] + monoview_algos = args["algos_monoview"] + multiview_algos = args["algos_multiview"] + dataset_list = execution.find_dataset_names(args["pathf"], + args["file_type"], + args["name"]) + if not args["add_noise"]: + args["noise_std"]=[0.0] + print(dataset_list) for dataset_name in dataset_list: noise_results = [] - for noise_std in args["Base"]["noise_std"]: + for noise_std in args["noise_std"]: + + directory = execution.init_log_file(dataset_name, args["views"], args["file_type"], + args["log"], args["debug"], args["label"], + args["res_dir"], args["add_noise"], noise_std, args) - directory = execution.init_log_file(dataset_name, args["Base"]["views"], args["Classification"]["type"], - args["Base"]["log"], args["Base"]["debug"], args["Base"]["label"], - args["Base"]["res_dir"], args["Base"]["add_noise"], noise_std, args) - random_state = execution.init_random_state(args["Base"]["random_state"], directory) + random_state = execution.init_random_state(args["random_state"], directory) stats_iter_random_states = execution.init_stats_iter_random_states(stats_iter, random_state) - get_database = execution.get_database_function(dataset_name, args["Base"]["type"]) + get_database = execution.get_database_function(dataset_name, args["file_type"]) - dataset_var, labels_dictionary, datasetname = 
-                                                                       args["Base"]["pathf"], dataset_name,
-                                                                       args["Classification"]["nb_class"],
-                                                                       args["Classification"]["classes"],
+            dataset_var, labels_dictionary, datasetname = get_database(args["views"],
+                                                                       args["pathf"], dataset_name,
+                                                                       args["nb_class"],
+                                                                       args["classes"],
                                                                        random_state,
-                                                                       args["Base"]["full"],
-                                                                       args["Base"]["add_noise"],
+                                                                       args["full"],
+                                                                       args["add_noise"],
                                                                        noise_std)
-            args["Base"]["name"] = datasetname
+            args["name"] = datasetname
 
-            splits = execution.gen_splits(dataset_var.get_labels(), args["Classification"]["split"],
+            splits = execution.gen_splits(dataset_var.get_labels(), args["split"],
                                           stats_iter_random_states)
 
             # multiclass_labels, labels_combinations, indices_multiclass = multiclass.gen_multiclass_labels(
             #     dataset_var.get_labels(), multiclass_method, splits)
 
-            k_folds = execution.gen_k_folds(stats_iter, args["Classification"]["nb_folds"],
+            k_folds = execution.gen_k_folds(stats_iter, args["nb_folds"],
                                             stats_iter_random_states)
 
-            dataset_files = dataset.init_multiple_datasets(args["Base"]["pathf"], args["Base"]["name"], nb_cores)
+            dataset_files = dataset.init_multiple_datasets(args["pathf"], args["name"], nb_cores)
 
-            views, views_indices, all_views = execution.init_views(dataset_var, args["Base"]["views"])
+            views, views_indices, all_views = execution.init_views(dataset_var, args["views"])
             views_dictionary = dataset_var.get_view_dict()
             nb_views = len(views)
             nb_class = dataset_var.get_nb_class()
 
-            metrics = [metric.split(":") for metric in args["Classification"]["metrics"]]
+            metrics = [metric.split(":") for metric in args["metrics"]]
             if metrics == [["all"]]:
                 metrics_names = [name for _, name, isPackage in pkgutil.iter_modules(
@@ -875,7 +876,7 @@ def exec_classif(arguments):
                                      "matthews_corrcoef", "roc_auc_score"]]
                 metrics = [[metricName] for metricName in metrics_names]
-            metrics = arange_metrics(metrics, args["Classification"]["metric_princ"])
+            metrics = arange_metrics(metrics, args["metric_princ"])
             for metricIndex, metric in enumerate(metrics):
                 if len(metric) == 1:
                     metrics[metricIndex] = [metric[0], None]
@@ -899,7 +900,7 @@ def exec_classif(arguments):
             results_mean_stds = exec_benchmark(
                 nb_cores, stats_iter, benchmark_argument_dictionaries,
                 directory, metrics, dataset_var,
-                args["Base"]["track_tracebacks"])
+                args["track_tracebacks"])
             noise_results.append([noise_std, results_mean_stds])
 
         plot_results_noise(directory, noise_results, metrics[0][0], dataset_name)
diff --git a/multiview_platform/mono_multi_view_classifiers/result_analysis.py b/multiview_platform/mono_multi_view_classifiers/result_analysis.py
index 3f1b93dbf9a25d9fa753341c48063977ae4149d2..c5017e786f61912a9c4bc133a059b9021a11078a 100644
--- a/multiview_platform/mono_multi_view_classifiers/result_analysis.py
+++ b/multiview_platform/mono_multi_view_classifiers/result_analysis.py
@@ -738,7 +738,7 @@ def analyze_iterations(results, benchmark_argument_dictionaries, stats_iter,
         feature_importances = get_feature_importances(result)
         directory = arguments["directory"]
 
-        database_name = arguments["args"]["Base"]["name"]
+        database_name = arguments["args"]["name"]
         labels_names = [arguments["labels_dictionary"][0],
                         arguments["labels_dictionary"][1]]
 
@@ -916,10 +916,10 @@ def analyze_iterations(results, benchmark_argument_dictionaries, stats_iter,
 #                                                  multiclass_labels)
 #
 #     results = publishMulticlassScores(multiclass_results, metrics, stats_iter, directories,
-#                                       benchmark_argument_dictionaries[0]["args"]["Base"]["name"])
+#                                       benchmark_argument_dictionaries[0]["args"]["name"])
 #     publishMulticlassExmapleErrors(multiclass_results, directories,
 #                                    benchmark_argument_dictionaries[0][
-#                                        "args"]["Base"]["name"], example_ids, multiclass_labels)
+#                                        "args"]["name"], example_ids, multiclass_labels)
 #
 #     return results, multiclass_results
@@ -982,14 +982,14 @@ def publish_all_example_errors(iter_results, directory,
     error_on_examples, classifier_names = gen_error_data_glob(iter_results,
                                                                stats_iter)
 
-    np.savetxt(directory + "clf_errors.csv", data, delimiter=",")
-    np.savetxt(directory + "example_errors.csv", error_on_examples,
+    np.savetxt(os.path.join(directory, "clf_errors.csv"), data, delimiter=",")
+    np.savetxt(os.path.join(directory, "example_errors.csv"), error_on_examples,
               delimiter=",")
 
     plot_2d(data, classifier_names, nbClassifiers, nbExamples,
-            directory, stats_iter=stats_iter, example_ids=example_ids, labels=labels)
+            os.path.join(directory, ""), stats_iter=stats_iter, example_ids=example_ids, labels=labels)
     plot_errors_bar(error_on_examples, nbClassifiers * stats_iter,
-                    nbExamples, directory)
+                    nbExamples, os.path.join(directory, ""))
 
     logging.debug(
         "Done:\t Global biclass label analysis figures generation")
@@ -1207,7 +1207,7 @@ def get_results(results, stats_iter, benchmark_argument_dictionaries, metrics,
                 directory, example_ids, labels):
     """Used to analyze the results of the previous benchmarks"""
-    data_base_name = benchmark_argument_dictionaries[0]["args"]["Base"]["name"]
+    data_base_name = benchmark_argument_dictionaries[0]["args"]["name"]
     results_means_std, biclass_results, flagged_failed = analyze_iterations(results,
                                                                             benchmark_argument_dictionaries,
diff --git a/multiview_platform/mono_multi_view_classifiers/utils/configuration.py b/multiview_platform/mono_multi_view_classifiers/utils/configuration.py
index 0f044c61ad0a0ef83083abb1d9a603ff57d4eb08..1d43abc8a67e2f4cc1efa4a4ae9855d6dd661ad0 100644
--- a/multiview_platform/mono_multi_view_classifiers/utils/configuration.py
+++ b/multiview_platform/mono_multi_view_classifiers/utils/configuration.py
@@ -18,7 +18,73 @@ def get_the_args(path_to_config_file="../config_files/config.yml"):
     """
     with open(path_to_config_file, 'r') as stream:
         yaml_config = yaml.safe_load(stream)
-    return yaml_config
+    return pass_default_config(**yaml_config)
+
+
+def pass_default_config(log=True,
+                        name=["plausible",],
+                        label="_",
+                        file_type=".hdf5",
+                        views=None,
+                        pathf="../data/",
+                        nice=0,
+                        random_state=42,
+                        nb_cores=1,
+                        full=True,
+                        debug=False,
+                        add_noise=False,
+                        noise_std=0.0,
+                        res_dir="../results/",
+                        track_tracebacks=False,
+                        multiclass_method="oneVersusOne",
+                        split=0.49,
+                        nb_folds=5,
+                        nb_class=None,
+                        classes=None,
+                        type=["multiview",],
+                        algos_monoview=["all" ],
+                        algos_multiview=["svm_jumbo_fusion",],
+                        stats_iter=2,
+                        metrics=["accuracy_score", "f1_score"],
+                        metric_princ="f1_score",
+                        hps_type="randomized_search",
+                        hps_iter=1, **kwargs):
+    """
+
+    :param log:
+    :param name:
+    :param label:
+    :param file_type:
+    :param views:
+    :param pathf:
+    :param nice:
+    :param random_state:
+    :param nb_cores:
+    :param full:
+    :param debug:
+    :param add_noise:
+    :param noise_std:
+    :param res_dir:
+    :param track_tracebacks:
+    :param multiclass_method:
+    :param split:
+    :param nb_folds:
+    :param nb_class:
+    :param classes:
+    :param type:
+    :param algos_monoview:
+    :param algos_multiview:
+    :param stats_iter:
+    :param metrics:
+    :param metric_princ:
+    :param hps_type:
+    :param hps_iter:
+    :return:
+    """
+    args = dict((key, value) for key, value in locals().items() if key !="kwargs")
+    args = dict(args, **kwargs)
+    return args
+
 
 
 def save_config(directory, arguments):
diff --git a/multiview_platform/mono_multi_view_classifiers/utils/dataset.py b/multiview_platform/mono_multi_view_classifiers/utils/dataset.py
index 3b34cdab2e5985a78b0768b6d45cb322f0dd6bb2..c75fa0fccb637e394172befd438c80daa28e3f6b 100644
--- a/multiview_platform/mono_multi_view_classifiers/utils/dataset.py
+++ b/multiview_platform/mono_multi_view_classifiers/utils/dataset.py
@@ -677,7 +677,7 @@ def delete_HDF5(benchmarkArgumentsDictionaries, nbCores, dataset):
         logging.debug("Start:\t Deleting datasets for multiprocessing")
 
         for coreIndex in range(nbCores):
-            os.remove(args["Base"]["pathf"] + args["Base"]["name"] + str(coreIndex) + ".hdf5")
+            os.remove(args["pathf"] + args["name"] + str(coreIndex) + ".hdf5")
     if dataset.is_temp:
         dataset.rm()
 
diff --git a/multiview_platform/mono_multi_view_classifiers/utils/execution.py b/multiview_platform/mono_multi_view_classifiers/utils/execution.py
index ec5775af94fea8e48fa57e78924c4a9bf8eb78fe..08244d0e0d07103baaf66f0f87c1c7b911728a82 100644
--- a/multiview_platform/mono_multi_view_classifiers/utils/execution.py
+++ b/multiview_platform/mono_multi_view_classifiers/utils/execution.py
@@ -156,7 +156,7 @@ def init_log_file(name, views, cl_type, log, debug, label,
     result_directory = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))), result_directory)
     if debug:
         result_directory = os.path.join(result_directory, name, noise_string,
-                                         "debug_started_" + time.strftime("%Y_%m_%d-%H_%M_%S") + "_" + label)
+                                        "debug_started_" + time.strftime("%Y_%m_%d-%H_%M_%S") + "_" + label)
     else:
         result_directory = os.path.join(result_directory, name, noise_string,
                                         "started_" + time.strftime("%Y_%m_%d-%H_%M") + "_" + label)
@@ -321,6 +321,9 @@ def find_dataset_names(path, type, names):
     if names == ["all"]:
         return available_file_names
     elif len(names)>1:
+        selected_names = [used_name for used_name in available_file_names if used_name in names]
+        if not selected_names:
+            raise ValueError("None of the provided dataset names are available. Available datasets are {}".format(available_file_names))
         return [used_name for used_name in available_file_names if used_name in names]
     else:
         return names
diff --git a/multiview_platform/tests/test_utils/test_configuration.py b/multiview_platform/tests/test_utils/test_configuration.py
index 2d2b9f41ca6c4a105ee063bb5c5cf21435aea3a7..3498329fcb46b8e8239034b147573468ebc69b21 100644
--- a/multiview_platform/tests/test_utils/test_configuration.py
+++ b/multiview_platform/tests/test_utils/test_configuration.py
@@ -16,7 +16,7 @@ class Test_get_the_args(unittest.TestCase):
         path_file = os.path.dirname(os.path.abspath(__file__))
         make_tmp_dir = os.path.join(path_file, "../tmp_tests")
         os.mkdir(make_tmp_dir)
-        data = {"Base":{"first_arg": 10, "second_arg":[12.5, 1e-06]}, "Classification":{"third_arg":True}}
+        data = {"log": 10, "name":[12.5, 1e-06], "type":True}
         with open(cls.path_to_config_file, "w") as config_file:
             yaml.dump(data, config_file)
 
@@ -31,16 +31,14 @@ class Test_get_the_args(unittest.TestCase):
 
     def test_dict_format(self):
         config_dict = configuration.get_the_args(self.path_to_config_file)
-        self.assertIn("Base", config_dict)
-        self.assertIn("Classification", config_dict)
-        self.assertIn("first_arg", config_dict["Base"])
-        self.assertIn("third_arg", config_dict["Classification"])
+        self.assertIn("log", config_dict)
+        self.assertIn("name", config_dict)
 
     def test_arguments(self):
         config_dict = configuration.get_the_args(self.path_to_config_file)
-        self.assertEqual(config_dict["Base"]["first_arg"], 10)
-        self.assertEqual(config_dict["Base"]["second_arg"], [12.5, 1e-06])
-        self.assertEqual(config_dict["Classification"]["third_arg"], True)
+        self.assertEqual(config_dict["log"], 10)
+        self.assertEqual(config_dict["name"], [12.5, 1e-06])
+        self.assertEqual(config_dict["type"], True)
 
 # class Test_format_the_args(unittest.TestCase):
 #
diff --git a/multiview_platform/tests/test_utils/test_hyper_parameter_search.py b/multiview_platform/tests/test_utils/test_hyper_parameter_search.py
index e1f6db063d3fe6fea97b957583db7796ae8610c8..5ee7c0bb62be208923493b10235c1a45c38cd106 100644
--- a/multiview_platform/tests/test_utils/test_hyper_parameter_search.py
+++ b/multiview_platform/tests/test_utils/test_hyper_parameter_search.py
@@ -150,6 +150,7 @@ class Test_MultiviewCompatibleRandomizedSearchCV(unittest.TestCase):
         self.assertEqual(RSCV.n_iter, self.n_iter)
 
     def test_fit_multiview_equiv(self):
+        self.n_iter=1
         RSCV = hyper_parameter_search.MultiviewCompatibleRandomizedSearchCV(
             FakeEstimMV(), self.param_distributions, n_iter=self.n_iter,
             refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring,