diff --git a/config_files/config.yml b/config_files/config.yml index 5c3f48fc24d18b5e334592eea9446d6045da8efe..bcd99f24eea4b96b91abe4e562cb82304a02aaa9 100644 --- a/config_files/config.yml +++ b/config_files/config.yml @@ -137,4 +137,9 @@ min_cq: weighted_linear_early_fusion: view_weights: [None] - monoview_classifier: ["decision_tree"] + monoview_classifier_name: ["decision_tree"] + monoview_classifier_config: + decision_tree: + max_depth: [1] + criterion: ["gini"] + splitter: ["best"] diff --git a/ipynb/FeatureExtraction-All_unix.ipynb b/ipynb/FeatureExtraction-All_unix.ipynb index 42b4e2e64409c15503e9af8034c09798f00a90f1..fd972a80342a04b04fbd8a6b243fd211667c6e5e 100644 --- a/ipynb/FeatureExtraction-All_unix.ipynb +++ b/ipynb/FeatureExtraction-All_unix.ipynb @@ -480,7 +480,6 @@ { "ename": "ValueError", "evalue": "all the input array dimensions except for the concatenation axis must match exactly", - "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)", @@ -488,7 +487,8 @@ "\u001b[1;32m<ipython-input-13-c8d37ffc0446>\u001b[0m in \u001b[0;36mcalcSurfHisto\u001b[1;34m(dfImages_, k_)\u001b[0m\n\u001b[0;32m 39\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mdescriptor\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mdes_list\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 40\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 41\u001b[1;33m \u001b[0mdescriptors\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvstack\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdescriptors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdescriptor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 42\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 43\u001b[0m \u001b[1;31m#### Bag of Words Approach\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m/home/doob/anaconda2/lib/python2.7/site-packages/numpy/core/shape_base.pyc\u001b[0m in \u001b[0;36mvstack\u001b[1;34m(tup)\u001b[0m\n\u001b[0;32m 228\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 229\u001b[0m \"\"\"\n\u001b[1;32m--> 230\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0m_nx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0matleast_2d\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_m\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0m_m\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mtup\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 231\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 232\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mhstack\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtup\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mValueError\u001b[0m: all the input array dimensions except for the concatenation axis must match exactly" - ] + ], + "output_type": "error" } ], "source": [ diff --git a/multiview_platform/mono_multi_view_classifiers/exec_classif.py b/multiview_platform/mono_multi_view_classifiers/exec_classif.py index cd066e7274fbda67eeffe406339c0decae445a55..8f8ebb7ece1cb8b2e67b744d8eae4725adae17e7 100644 --- a/multiview_platform/mono_multi_view_classifiers/exec_classif.py +++ b/multiview_platform/mono_multi_view_classifiers/exec_classif.py @@ -200,7 +200,7 @@ def gen_single_multiview_arg_dictionary(classifier_name,arguments,nb_class, "nb_class": nb_class, "labels_names": None, classifier_name: dict((key, value[0]) for key, value in arguments[ - classifier_name].items()) + classifier_name].items() if isinstance(value, list)) } diff --git a/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/adaboost.py b/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/adaboost.py index 3b2952e4481e0ca2bf1a4510751c742f7ef2699e..e079dae7b54fda07d3e036bf21d0d762f6f47d30 100644 --- a/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/adaboost.py +++ b/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/adaboost.py @@ -84,12 +84,6 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier): np.array([self.train_time, self.pred_time]), delimiter=',') return interpretString -# -# def formatCmdArgs(args): -# """Used to format kwargs for the parsed args""" -# kwargsDict = {'n_estimators': args.Ada_n_est, -# 'base_estimator': [DecisionTreeClassifier(max_depth=1)]} -# return kwargsDict def paramsToSet(nIter, random_state): diff --git a/multiview_platform/mono_multi_view_classifiers/multiview/multiview_utils.py b/multiview_platform/mono_multi_view_classifiers/multiview/multiview_utils.py index 51a9a5686f1ccb3ac3f9d6b05546eadfeb10859b..dd4767d557d8b05d4d1d8a3676d9d162e1e12364 100644 --- a/multiview_platform/mono_multi_view_classifiers/multiview/multiview_utils.py +++ b/multiview_platform/mono_multi_view_classifiers/multiview/multiview_utils.py @@ -97,21 +97,20 @@ def get_train_views_indices(dataset, train_indices, view_indices,): class ConfigGenerator(): - def __init__(self): + def __init__(self, classifier_names): self.distribs = {} - for name in dir(monoview_classifiers): - if not name.startswith("__"): - module = getattr(monoview_classifiers, name) - classifier_class = getattr(module, - module.classifier_class_name)() - self.distribs[name] = dict((param_name, param_distrib) - for param_name, param_distrib in - zip(classifier_class().param_names, - classifier_class().distribs)) + for classifier_name in classifier_names: + classifier_module = getattr(monoview_classifiers, classifier_name) + classifier_class = getattr(classifier_module, classifier_module.classifier_class_name) + self.distribs[classifier_name] = dict((param_name, param_distrib) + for param_name, param_distrib in + zip(classifier_class().param_names, + classifier_class().distribs)) def rvs(self, random_state=None): config_sample = {} for classifier_name, classifier_config in self.distribs.items(): + config_sample[classifier_name] = {} for param_name, param_distrib in classifier_config.items(): if hasattr(param_distrib, "rvs"): config_sample[classifier_name][param_name]=param_distrib.rvs(random_state=random_state) diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py index 303f52157a858dde1d1f0a6769a9d139de0edc90..25539a9f90982fb05de991587693f7f123bb29bb 100644 --- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py +++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py @@ -1,5 +1,5 @@ import numpy as np -import pkgutil +import inspect from ..utils.dataset import getV from ..multiview.multiview_utils import BaseMultiviewClassifier, get_train_views_indices, ConfigGenerator @@ -11,39 +11,46 @@ classifier_class_name = "WeightedLinearEarlyFusion" class WeightedLinearEarlyFusion(BaseMultiviewClassifier): def __init__(self, random_state=None, view_weights=None, - monoview_classifier="decision_tree", + monoview_classifier_name="decision_tree", monoview_classifier_config={}): super(WeightedLinearEarlyFusion, self).__init__(random_state=random_state) self.view_weights = view_weights - if isinstance(monoview_classifier, str): - self.short_name = "early fusion "+monoview_classifier - monoview_classifier_module = getattr(monoview_classifiers, - monoview_classifier) - monoview_classifier_class = getattr(monoview_classifier_module, - monoview_classifier_module.classifier_class_name) - self.monoview_classifier = monoview_classifier_class(random_state=random_state, - **monoview_classifier_config) - else: - self.monoview_classifier = monoview_classifier(monoview_classifier_config) - self.short_name = "early fusion "+self.monoview_classifier.__class__.__name__ - - self.param_names = ["monoview_classifier","random_state", "monoview_classifier_config"] - classifier_classes = [] - for name in dir(monoview_classifiers): - if not name.startswith("__"): - module = getattr(monoview_classifiers, name) - classifier_class = getattr(module, module.classifier_class_name) - classifier_classes.append(classifier_class) - self.distribs = [classifier_classes, [self.random_state], ConfigGenerator()] - self.classed_params = ["monoview_classifier"] - self.weird_strings={"monoview_classifier":["class_name", "config"]} - - def set_params(self, monoview_classifier=None, monoview_classifier_config=None, **params): - monoview_classifier_name = monoview_classifier.__module__ - self.monoview_classifier = monoview_classifier() + self.monoview_classifier_name = monoview_classifier_name + self.short_name = "early fusion " + monoview_classifier_name + self.monoview_classifier_config = monoview_classifier_config + + monoview_classifier_module = getattr(monoview_classifiers, + self.monoview_classifier_name) + monoview_classifier_class = getattr(monoview_classifier_module, + monoview_classifier_module.classifier_class_name) + self.monoview_classifier = monoview_classifier_class(random_state=random_state, + **self.monoview_classifier_config) + + self.param_names = ["monoview_classifier_name", "monoview_classifier_config"] + classifier_names = [] + for module_name in dir(monoview_classifiers): + if not module_name.startswith("__"): + classifier_names.append(module_name) + self.distribs = [classifier_names, ConfigGenerator(classifier_names)] + self.classed_params = [] + self.weird_strings={} + + def set_params(self, monoview_classifier_name=None, monoview_classifier_config=None, **params): + self.monoview_classifier_name = monoview_classifier_name + monoview_classifier_module = getattr(monoview_classifiers, + self.monoview_classifier_name) + monoview_classifier_class = getattr(monoview_classifier_module, + monoview_classifier_module.classifier_class_name) + self.monoview_classifier = monoview_classifier_class() self.set_monoview_classifier_config(monoview_classifier_name, monoview_classifier_config) + return self + def get_params(self, deep=True): + return {"random_state":self.random_state, + "view_weights":self.view_weights, + "monoview_classifier_name":self.monoview_classifier_name, + "monoview_classifier_config":self.monoview_classifier_config} def fit(self, X, y, train_indices=None, view_indices=None): train_indices, X = self.transform_data_to_monoview(X, train_indices, view_indices) @@ -80,9 +87,9 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier): def set_monoview_classifier_config(self, monoview_classifier_name, monoview_classifier_config): if monoview_classifier_name in monoview_classifier_config: - self.monoview_classifier.set_params(monoview_classifier_config[monoview_classifier_name]) + self.monoview_classifier.set_params(**monoview_classifier_config[monoview_classifier_name]) else: - self.monoview_classifier.set_params(monoview_classifier_config) + self.monoview_classifier.set_params(**monoview_classifier_config) diff --git a/multiview_platform/mono_multi_view_classifiers/utils/configuration.py b/multiview_platform/mono_multi_view_classifiers/utils/configuration.py index 4534c685529b9978ba135ad54121f3b42340b737..72318b2a1cdac0f246193e5e33e39bdf120f2487 100644 --- a/multiview_platform/mono_multi_view_classifiers/utils/configuration.py +++ b/multiview_platform/mono_multi_view_classifiers/utils/configuration.py @@ -1,4 +1,3 @@ -import configparser import builtins from distutils.util import strtobool as tobool import yaml @@ -9,34 +8,3 @@ def get_the_args(path_to_config_file="../config_files/config.yml"): with open(path_to_config_file, 'r') as stream: yaml_config = yaml.safe_load(stream) return yaml_config - - # config_parser = configparser.ConfigParser(comment_prefixes=('#')) - # config_parser.read(path_to_config_file) - # config_dict = {} - # for section in config_parser: - # config_dict[section] = {} - # for key in config_parser[section]: - # value = format_raw_arg(config_parser[section][key]) - # config_dict[section][key] = value - # return config_dict - - -def format_raw_arg(raw_arg): - """This function is used to convert the raw arg in a types value. - For example, 'list_int ; 10 20' will be formatted in [10,20]""" - function_name, raw_value = raw_arg.split(" ; ") - if function_name.startswith("list"): - function_name = function_name.split("_")[1] - raw_values = raw_value.split(" ") - value = [getattr(builtins, function_name)(raw_value) - if function_name != "bool" else bool(tobool(raw_value)) - for raw_value in raw_values] - else: - if raw_value == "None": - value = None - else: - if function_name=="bool": - value = bool(tobool(raw_value)) - else: - value = getattr(builtins, function_name)(raw_value) - return value diff --git a/multiview_platform/mono_multi_view_classifiers/utils/dataset.py b/multiview_platform/mono_multi_view_classifiers/utils/dataset.py index 6a19674c653148ab554121ba41f8f31c66cf9904..2a37b3d12408777f022309c12706d174ea2abcc1 100644 --- a/multiview_platform/mono_multi_view_classifiers/utils/dataset.py +++ b/multiview_platform/mono_multi_view_classifiers/utils/dataset.py @@ -10,6 +10,8 @@ from scipy import sparse from . import get_multiview_db as DB + + def getV(DATASET, viewIndex, usedIndices=None): """Used to extract a view as a numpy array or a sparse mat from the HDF5 dataset""" if usedIndices is None: diff --git a/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py b/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py index 6c066e1cdcf6ff4753179e2f138278f13c152185..6bbf2eb4acb77a67d069bcbb25fa86710ea640d6 100644 --- a/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py +++ b/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py @@ -107,7 +107,7 @@ def randomized_search(X, y, framework, random_state, output_file_name, classifie min_list = np.array( [min(nb_possible_combination, n_iter) for nb_possible_combination in nb_possible_combinations]) - randomSearch = MultiviewCompatibleRandomizedSearchCV(estimator, + random_search = MultiviewCompatibleRandomizedSearchCV(estimator, n_iter=int(np.sum(min_list)), param_distributions=params_dict, refit=True, @@ -116,23 +116,25 @@ def randomized_search(X, y, framework, random_state, output_file_name, classifie learning_indices=learning_indices, view_indices=view_indices, framework = framework) - detector = randomSearch.fit(X, y) - - bestParams = dict((key, value) for key, value in - estimator.genBestParams(detector).items() if - key is not "random_state") - - scoresArray = detector.cv_results_['mean_test_score'] - params = estimator.genParamsFromDetector(detector) - - genHeatMaps(params, scoresArray, output_file_name) - best_estimator = detector.best_estimator_ + random_search.fit(X, y) + best_params = random_search.best_params_ + if "random_state" in best_params: + best_params.pop("random_state") + + # bestParams = dict((key, value) for key, value in + # estimator.genBestParams(detector).items() if + # key is not "random_state") + + scoresArray = random_search.cv_results_['mean_test_score'] + params = [(key[6:], value ) for key, value in random_search.cv_results_.items() if key.startswith("param_")] + # genHeatMaps(params, scoresArray, output_file_name) + best_estimator = random_search.best_estimator_ else: best_estimator = estimator - bestParams = {} + best_params = {} testFoldsPreds = get_test_folds_preds(X, y, folds, best_estimator, framework, learning_indices) - return bestParams, testFoldsPreds + return best_params, testFoldsPreds from sklearn.base import clone @@ -192,6 +194,7 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV): self.best_score_ = cross_validation_score if self.refit: self.best_estimator_ = clone(base_estimator).set_params(**self.best_params_) + self.best_estimator_.fit(X, y, **fit_params) self.n_splits_ = n_splits return self diff --git a/multiview_platform/tests/test_ExecClassif.py b/multiview_platform/tests/test_ExecClassif.py index cd9545cb9371f1b61afa9e709a919fb0f758f12b..1af0dc200460ac4b1b3ef60549df5a6969424438 100644 --- a/multiview_platform/tests/test_ExecClassif.py +++ b/multiview_platform/tests/test_ExecClassif.py @@ -4,6 +4,8 @@ import unittest import h5py import numpy as np +from .utils import rm_tmp + from ..mono_multi_view_classifiers import exec_classif @@ -25,6 +27,7 @@ class Test_initKWARGS(unittest.TestCase): class Test_init_argument_dictionaries(unittest.TestCase): @classmethod def setUpClass(cls): + rm_tmp() cls.benchmark = {"monoview": ["fake_monoview_classifier"], "multiview": {}} cls.views_dictionnary = {'test_view_0': 0, 'test_view': 1} cls.nb_class = 2 @@ -113,6 +116,7 @@ class Test_execBenchmark(unittest.TestCase): @classmethod def setUpClass(cls): + rm_tmp() os.mkdir("multiview_platform/tests/tmp_tests") cls.Dataset = h5py.File( "multiview_platform/tests/tmp_tests/test_file.hdf5", "w") @@ -216,6 +220,7 @@ class Test_execOneBenchmark(unittest.TestCase): @classmethod def setUp(cls): + rm_tmp() os.mkdir("multiview_platform/tests/tmp_tests") cls.args = { "Base": {"name": "chicken_is_heaven", "type": "type", @@ -282,6 +287,7 @@ class Test_execOneBenchmark_multicore(unittest.TestCase): @classmethod def setUpClass(cls): + rm_tmp() os.mkdir("multiview_platform/tests/tmp_tests") cls.args = { "Base": {"name": "chicken_is_heaven", "type": "type", diff --git a/multiview_platform/tests/test_mono_view/test_ExecClassifMonoView.py b/multiview_platform/tests/test_mono_view/test_ExecClassifMonoView.py index 78a4bdb2a570e48aadef94a4b9138dcbd74bc7f4..c5e70b74c7aadddf7b7d93bef5030e2911d1f647 100644 --- a/multiview_platform/tests/test_mono_view/test_ExecClassifMonoView.py +++ b/multiview_platform/tests/test_mono_view/test_ExecClassifMonoView.py @@ -5,6 +5,8 @@ import h5py import numpy as np from sklearn.model_selection import StratifiedKFold +from ..utils import rm_tmp + from ...mono_multi_view_classifiers.monoview import exec_classif_mono_view from ...mono_multi_view_classifiers.monoview_classifiers import decision_tree @@ -13,6 +15,7 @@ class Test_initConstants(unittest.TestCase): @classmethod def setUpClass(cls): + rm_tmp() os.mkdir("multiview_platform/tests/temp_tests") cls.datasetFile = h5py.File( "multiview_platform/tests/temp_tests/test.hdf5", "w") @@ -65,6 +68,7 @@ class Test_initTrainTest(unittest.TestCase): @classmethod def setUpClass(cls): + rm_tmp() cls.random_state = np.random.RandomState(42) cls.X = cls.random_state.randint(0, 500, (10, 5)) cls.Y = cls.random_state.randint(0, 2, 10) @@ -95,6 +99,7 @@ class Test_getHPs(unittest.TestCase): @classmethod def setUpClass(cls): + rm_tmp() os.mkdir("multiview_platform/tests/tmp_tests") cls.classifierModule = decision_tree cls.hyperParamSearch = "randomized_search" diff --git a/multiview_platform/tests/test_multiview_classifiers/test_diversity_utils.py b/multiview_platform/tests/test_multiview_classifiers/test_diversity_utils.py index 1b70609141096f5d6d3870393cdffb2334a0a19b..11b74a75abc59b18e5d2dd8a80537a3811fd9155 100644 --- a/multiview_platform/tests/test_multiview_classifiers/test_diversity_utils.py +++ b/multiview_platform/tests/test_multiview_classifiers/test_diversity_utils.py @@ -2,6 +2,7 @@ import unittest import numpy as np +from ..utils import rm_tmp from multiview_platform.mono_multi_view_classifiers.multiview.additions import \ diversity_utils @@ -14,6 +15,7 @@ class Test_global_div_measure(unittest.TestCase): @classmethod def setUpClass(cls): + rm_tmp() cls.randomState = np.random.RandomState(42) cls.allClassifiersNames = [["SCM", "SVM", "DT"], ["SCM", "SVM", "DT"]] cls.viewsIndices = np.array([0, 1]) diff --git a/multiview_platform/tests/test_multiview_classifiers/test_weighted_linear_early_fusion.py b/multiview_platform/tests/test_multiview_classifiers/test_weighted_linear_early_fusion.py index d78acf8020081205e42e36d79cc936d8511be72c..3fb36b314d67184ed43c801ab8e8e355d9ff24d3 100644 --- a/multiview_platform/tests/test_multiview_classifiers/test_weighted_linear_early_fusion.py +++ b/multiview_platform/tests/test_multiview_classifiers/test_weighted_linear_early_fusion.py @@ -4,6 +4,8 @@ import numpy as np import h5py import os +from ..utils import rm_tmp + from multiview_platform.mono_multi_view_classifiers.multiview_classifiers import \ weighted_linear_early_fusion @@ -11,6 +13,7 @@ class Test_WeightedLinearEarlyFusion(unittest.TestCase): @classmethod def setUpClass(cls): + rm_tmp() cls.random_state = np.random.RandomState(42) cls.view_weights = [0.5, 0.5] os.mkdir("multiview_platform/tests/tmp_tests") @@ -32,7 +35,7 @@ class Test_WeightedLinearEarlyFusion(unittest.TestCase): cls.monoview_classifier_config = {"max_depth":1, "criterion": "gini", "splitter": "best"} cls.classifier = weighted_linear_early_fusion.WeightedLinearEarlyFusion( random_state=cls.random_state, view_weights=cls.view_weights, - monoview_classifier=cls.monoview_classifier_name, + monoview_classifier_name=cls.monoview_classifier_name, monoview_classifier_config=cls.monoview_classifier_config) @classmethod diff --git a/multiview_platform/tests/test_utils/test_configuration.py b/multiview_platform/tests/test_utils/test_configuration.py index c1e8c3b47125380c120e8516c6880f115b6f6bc4..bc922be1e50733cc8dd61e343e2cd3f73d35ac45 100644 --- a/multiview_platform/tests/test_utils/test_configuration.py +++ b/multiview_platform/tests/test_utils/test_configuration.py @@ -3,11 +3,14 @@ import unittest import yaml import numpy as np +from ..utils import rm_tmp + from multiview_platform.mono_multi_view_classifiers.utils import configuration class Test_get_the_args(unittest.TestCase): def setUp(self): + rm_tmp() self.path_to_config_file = "multiview_platform/tests/tmp_tests/config_temp.yml" os.mkdir("multiview_platform/tests/tmp_tests") data = {"Base":{"first_arg": 10, "second_arg":[12.5, 1e-06]}, "Classification":{"third_arg":True}} diff --git a/multiview_platform/tests/test_utils/test_hyper_parameter_search.py b/multiview_platform/tests/test_utils/test_hyper_parameter_search.py index 0024a1427a85b07adbbd4f4ebee038fcf75cc28d..6b207372e6bf3fee23b5c1cf005b427d77ab0044 100644 --- a/multiview_platform/tests/test_utils/test_hyper_parameter_search.py +++ b/multiview_platform/tests/test_utils/test_hyper_parameter_search.py @@ -5,6 +5,8 @@ import h5py import numpy as np from sklearn.model_selection import StratifiedKFold +from ..utils import rm_tmp + from ...mono_multi_view_classifiers.utils import hyper_parameter_search from ...mono_multi_view_classifiers.multiview_classifiers import weighted_linear_early_fusion @@ -12,6 +14,7 @@ class Test_randomized_search(unittest.TestCase): @classmethod def setUpClass(cls): + rm_tmp() cls.random_state = np.random.RandomState(42) cls.view_weights = [0.5, 0.5] os.mkdir("multiview_platform/tests/tmp_tests") diff --git a/multiview_platform/tests/utils.py b/multiview_platform/tests/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5766b68885035b216a141caa8e273c0df74583ae --- /dev/null +++ b/multiview_platform/tests/utils.py @@ -0,0 +1,9 @@ +import os + +def rm_tmp(): + try: + for file_name in os.listdir("multiview_platform/tests/tmp_tests"): + os.remove(os.path.join("multiview_platform/tests/tmp_tests", file_name)) + os.rmdir("multiview_platform/tests/tmp_tests") + except: + pass