diff --git a/multiview_platform/tests/test_mono_view/test_ExecClassifMonoView.py b/multiview_platform/tests/test_mono_view/test_ExecClassifMonoView.py index 9d423f14a5152b64ff77c08bf19ed99ba94234c6..acc6a8fef00f2990b837d4f3fba11943826d4ef4 100644 --- a/multiview_platform/tests/test_mono_view/test_ExecClassifMonoView.py +++ b/multiview_platform/tests/test_mono_view/test_ExecClassifMonoView.py @@ -128,19 +128,19 @@ class Test_getHPs(unittest.TestCase): os.rmdir(tmp_path) def test_simple(self): - kwargs, test_folds_predictions = exec_classif_mono_view.getHPs(self.classifierModule, - self.hyper_param_search, - self.n_iter, - self.classifier_name, - self.classifier_class_name, - self.X, - self.y, - self.random_state, - self.output_file_name, - self.cv, - self.nb_cores, - self.metrics, - self.kwargs) + kwargs = exec_classif_mono_view.getHPs(self.classifierModule, + self.hyper_param_search, + self.n_iter, + self.classifier_name, + self.classifier_class_name, + self.X, + self.y, + self.random_state, + self.output_file_name, + self.cv, + self.nb_cores, + self.metrics, + self.kwargs) # class Test_getKWARGS(unittest.TestCase): # diff --git a/multiview_platform/tests/test_utils/test_hyper_parameter_search.py b/multiview_platform/tests/test_utils/test_hyper_parameter_search.py index 2257f14a4842ddea064e131bbb4f236b2f1a1d22..e1f6db063d3fe6fea97b957583db7796ae8610c8 100644 --- a/multiview_platform/tests/test_utils/test_hyper_parameter_search.py +++ b/multiview_platform/tests/test_utils/test_hyper_parameter_search.py @@ -4,7 +4,7 @@ import unittest import h5py import numpy as np from sklearn.model_selection import StratifiedKFold -from multiview_platform.tests.utils import rm_tmp, tmp_path +from multiview_platform.tests.utils import rm_tmp, tmp_path, test_dataset from multiview_platform.mono_multi_view_classifiers.utils.dataset import HDF5Dataset @@ -55,13 +55,116 @@ class Test_randomized_search(unittest.TestCase): def test_simple(self): - best_params, test_folds_preds = hyper_parameter_search.randomized_search( + best_params, _ = hyper_parameter_search.randomized_search( self.dataset, self.labels[()], "multiview", self.random_state, tmp_path, weighted_linear_early_fusion, "WeightedLinearEarlyFusion", self.k_folds, 1, ["accuracy_score", None], 2, {}, learning_indices=self.learning_indices) + self.assertIsInstance(best_params, dict) +from sklearn.base import BaseEstimator -if __name__ == '__main__': - # unittest.main() - suite = unittest.TestLoader().loadTestsFromTestCase(Test_randomized_search) - unittest.TextTestRunner(verbosity=2).run(suite) \ No newline at end of file +class FakeEstim(BaseEstimator): + def __init__(self, param1=None, param2=None): + self.param1 = param1 + self.param2 = param2 + + def fit(self, X, y,): + return self + + def predict(self, X): + return np.zeros(X.shape[0]) + +class FakeEstimMV(BaseEstimator): + def __init__(self, param1=None, param2=None): + self.param1 = param1 + self.param2 = param2 + + def fit(self, X, y,train_indices=None, view_indices=None): + return self + + def predict(self, X, example_indices=None, view_indices=None): + return np.zeros(example_indices.shape[0]) + +from sklearn.metrics import accuracy_score, make_scorer +from sklearn.model_selection import StratifiedKFold + +class Test_MultiviewCompatibleRandomizedSearchCV(unittest.TestCase): + + @classmethod + def setUpClass(cls): + n_splits=2 + cls.estimator = FakeEstim() + cls.param_distributions = {"param1":[10,100], "param2":[11, 101]} + cls.n_iter = 4 + cls.refit = True + cls.n_jobs = 1 + cls.scoring = make_scorer(accuracy_score, ) + cls.cv = StratifiedKFold(n_splits=n_splits, ) + cls.random_state = np.random.RandomState(42) + cls.learning_indices = np.array([0,1,2]) + cls.view_indices = None + cls.framework = "monoview" + cls.equivalent_draws = False + cls.X = cls.random_state.randint(0,100, (5,11)) + cls.y = cls.random_state.randint(0,1, 5) + + def test_simple(self): + hyper_parameter_search.MultiviewCompatibleRandomizedSearchCV( + self.estimator, self.param_distributions, n_iter=self.n_iter, + refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, cv=self.cv, + random_state=self.random_state, + learning_indices=self.learning_indices, view_indices=self.view_indices, + framework=self.framework, + equivalent_draws=self.equivalent_draws + ) + + def test_fit(self): + RSCV = hyper_parameter_search.MultiviewCompatibleRandomizedSearchCV( + self.estimator, self.param_distributions, n_iter=self.n_iter, + refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, + cv=self.cv, + random_state=self.random_state, + learning_indices=self.learning_indices, + view_indices=self.view_indices, + framework=self.framework, + equivalent_draws=self.equivalent_draws + ) + RSCV.fit(self.X, self.y, ) + tested_param1 = np.ma.masked_array(data=[10,10,100,100], + mask=[False, False, False, False]) + np.testing.assert_array_equal(RSCV.cv_results_['param_param1'], + tested_param1) + + def test_fit_multiview(self): + RSCV = hyper_parameter_search.MultiviewCompatibleRandomizedSearchCV( + FakeEstimMV(), self.param_distributions, n_iter=self.n_iter, + refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, + cv=self.cv, + random_state=self.random_state, + learning_indices=self.learning_indices, + view_indices=self.view_indices, + framework="multiview", + equivalent_draws=self.equivalent_draws + ) + RSCV.fit(test_dataset, self.y, ) + self.assertEqual(RSCV.n_iter, self.n_iter) + + def test_fit_multiview_equiv(self): + RSCV = hyper_parameter_search.MultiviewCompatibleRandomizedSearchCV( + FakeEstimMV(), self.param_distributions, n_iter=self.n_iter, + refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, + cv=self.cv, + random_state=self.random_state, + learning_indices=self.learning_indices, + view_indices=self.view_indices, + framework="multiview", + equivalent_draws=True + ) + RSCV.fit(test_dataset, self.y, ) + self.assertEqual(RSCV.n_iter, self.n_iter*test_dataset.nb_view) + + +# if __name__ == '__main__': +# # unittest.main() +# suite = unittest.TestLoader().loadTestsFromTestCase(Test_randomized_search) +# unittest.TextTestRunner(verbosity=2).run(suite) \ No newline at end of file