Skip to content
Snippets Groups Projects
Commit 54bbc20f authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Testing more stuff"

parent a723be1d
Branches
Tags
No related merge requests found
Pipeline #4093 failed
...@@ -128,7 +128,7 @@ class Test_getHPs(unittest.TestCase): ...@@ -128,7 +128,7 @@ class Test_getHPs(unittest.TestCase):
os.rmdir(tmp_path) os.rmdir(tmp_path)
def test_simple(self): def test_simple(self):
kwargs, test_folds_predictions = exec_classif_mono_view.getHPs(self.classifierModule, kwargs = exec_classif_mono_view.getHPs(self.classifierModule,
self.hyper_param_search, self.hyper_param_search,
self.n_iter, self.n_iter,
self.classifier_name, self.classifier_name,
......
...@@ -4,7 +4,7 @@ import unittest ...@@ -4,7 +4,7 @@ import unittest
import h5py import h5py
import numpy as np import numpy as np
from sklearn.model_selection import StratifiedKFold from sklearn.model_selection import StratifiedKFold
from multiview_platform.tests.utils import rm_tmp, tmp_path from multiview_platform.tests.utils import rm_tmp, tmp_path, test_dataset
from multiview_platform.mono_multi_view_classifiers.utils.dataset import HDF5Dataset from multiview_platform.mono_multi_view_classifiers.utils.dataset import HDF5Dataset
...@@ -55,13 +55,116 @@ class Test_randomized_search(unittest.TestCase): ...@@ -55,13 +55,116 @@ class Test_randomized_search(unittest.TestCase):
def test_simple(self): def test_simple(self):
best_params, test_folds_preds = hyper_parameter_search.randomized_search( best_params, _ = hyper_parameter_search.randomized_search(
self.dataset, self.labels[()], "multiview", self.random_state, tmp_path, self.dataset, self.labels[()], "multiview", self.random_state, tmp_path,
weighted_linear_early_fusion, "WeightedLinearEarlyFusion", self.k_folds, weighted_linear_early_fusion, "WeightedLinearEarlyFusion", self.k_folds,
1, ["accuracy_score", None], 2, {}, learning_indices=self.learning_indices) 1, ["accuracy_score", None], 2, {}, learning_indices=self.learning_indices)
self.assertIsInstance(best_params, dict)
from sklearn.base import BaseEstimator
if __name__ == '__main__': class FakeEstim(BaseEstimator):
# unittest.main() def __init__(self, param1=None, param2=None):
suite = unittest.TestLoader().loadTestsFromTestCase(Test_randomized_search) self.param1 = param1
unittest.TextTestRunner(verbosity=2).run(suite) self.param2 = param2
\ No newline at end of file
def fit(self, X, y,):
return self
def predict(self, X):
return np.zeros(X.shape[0])
class FakeEstimMV(BaseEstimator):
def __init__(self, param1=None, param2=None):
self.param1 = param1
self.param2 = param2
def fit(self, X, y,train_indices=None, view_indices=None):
return self
def predict(self, X, example_indices=None, view_indices=None):
return np.zeros(example_indices.shape[0])
from sklearn.metrics import accuracy_score, make_scorer
from sklearn.model_selection import StratifiedKFold
class Test_MultiviewCompatibleRandomizedSearchCV(unittest.TestCase):
@classmethod
def setUpClass(cls):
n_splits=2
cls.estimator = FakeEstim()
cls.param_distributions = {"param1":[10,100], "param2":[11, 101]}
cls.n_iter = 4
cls.refit = True
cls.n_jobs = 1
cls.scoring = make_scorer(accuracy_score, )
cls.cv = StratifiedKFold(n_splits=n_splits, )
cls.random_state = np.random.RandomState(42)
cls.learning_indices = np.array([0,1,2])
cls.view_indices = None
cls.framework = "monoview"
cls.equivalent_draws = False
cls.X = cls.random_state.randint(0,100, (5,11))
cls.y = cls.random_state.randint(0,1, 5)
def test_simple(self):
hyper_parameter_search.MultiviewCompatibleRandomizedSearchCV(
self.estimator, self.param_distributions, n_iter=self.n_iter,
refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, cv=self.cv,
random_state=self.random_state,
learning_indices=self.learning_indices, view_indices=self.view_indices,
framework=self.framework,
equivalent_draws=self.equivalent_draws
)
def test_fit(self):
RSCV = hyper_parameter_search.MultiviewCompatibleRandomizedSearchCV(
self.estimator, self.param_distributions, n_iter=self.n_iter,
refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring,
cv=self.cv,
random_state=self.random_state,
learning_indices=self.learning_indices,
view_indices=self.view_indices,
framework=self.framework,
equivalent_draws=self.equivalent_draws
)
RSCV.fit(self.X, self.y, )
tested_param1 = np.ma.masked_array(data=[10,10,100,100],
mask=[False, False, False, False])
np.testing.assert_array_equal(RSCV.cv_results_['param_param1'],
tested_param1)
def test_fit_multiview(self):
RSCV = hyper_parameter_search.MultiviewCompatibleRandomizedSearchCV(
FakeEstimMV(), self.param_distributions, n_iter=self.n_iter,
refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring,
cv=self.cv,
random_state=self.random_state,
learning_indices=self.learning_indices,
view_indices=self.view_indices,
framework="multiview",
equivalent_draws=self.equivalent_draws
)
RSCV.fit(test_dataset, self.y, )
self.assertEqual(RSCV.n_iter, self.n_iter)
def test_fit_multiview_equiv(self):
RSCV = hyper_parameter_search.MultiviewCompatibleRandomizedSearchCV(
FakeEstimMV(), self.param_distributions, n_iter=self.n_iter,
refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring,
cv=self.cv,
random_state=self.random_state,
learning_indices=self.learning_indices,
view_indices=self.view_indices,
framework="multiview",
equivalent_draws=True
)
RSCV.fit(test_dataset, self.y, )
self.assertEqual(RSCV.n_iter, self.n_iter*test_dataset.nb_view)
# if __name__ == '__main__':
# # unittest.main()
# suite = unittest.TestLoader().loadTestsFromTestCase(Test_randomized_search)
# unittest.TextTestRunner(verbosity=2).run(suite)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment