Skip to content
Snippets Groups Projects
Commit 54bbc20f authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Testing more stuff"

parent a723be1d
No related branches found
No related tags found
No related merge requests found
Pipeline #4093 failed
......@@ -128,7 +128,7 @@ class Test_getHPs(unittest.TestCase):
os.rmdir(tmp_path)
def test_simple(self):
kwargs, test_folds_predictions = exec_classif_mono_view.getHPs(self.classifierModule,
kwargs = exec_classif_mono_view.getHPs(self.classifierModule,
self.hyper_param_search,
self.n_iter,
self.classifier_name,
......
......@@ -4,7 +4,7 @@ import unittest
import h5py
import numpy as np
from sklearn.model_selection import StratifiedKFold
from multiview_platform.tests.utils import rm_tmp, tmp_path
from multiview_platform.tests.utils import rm_tmp, tmp_path, test_dataset
from multiview_platform.mono_multi_view_classifiers.utils.dataset import HDF5Dataset
......@@ -55,13 +55,116 @@ class Test_randomized_search(unittest.TestCase):
def test_simple(self):
best_params, test_folds_preds = hyper_parameter_search.randomized_search(
best_params, _ = hyper_parameter_search.randomized_search(
self.dataset, self.labels[()], "multiview", self.random_state, tmp_path,
weighted_linear_early_fusion, "WeightedLinearEarlyFusion", self.k_folds,
1, ["accuracy_score", None], 2, {}, learning_indices=self.learning_indices)
self.assertIsInstance(best_params, dict)
from sklearn.base import BaseEstimator
if __name__ == '__main__':
# unittest.main()
suite = unittest.TestLoader().loadTestsFromTestCase(Test_randomized_search)
unittest.TextTestRunner(verbosity=2).run(suite)
\ No newline at end of file
class FakeEstim(BaseEstimator):
def __init__(self, param1=None, param2=None):
self.param1 = param1
self.param2 = param2
def fit(self, X, y,):
return self
def predict(self, X):
return np.zeros(X.shape[0])
class FakeEstimMV(BaseEstimator):
def __init__(self, param1=None, param2=None):
self.param1 = param1
self.param2 = param2
def fit(self, X, y,train_indices=None, view_indices=None):
return self
def predict(self, X, example_indices=None, view_indices=None):
return np.zeros(example_indices.shape[0])
from sklearn.metrics import accuracy_score, make_scorer
from sklearn.model_selection import StratifiedKFold
class Test_MultiviewCompatibleRandomizedSearchCV(unittest.TestCase):
@classmethod
def setUpClass(cls):
n_splits=2
cls.estimator = FakeEstim()
cls.param_distributions = {"param1":[10,100], "param2":[11, 101]}
cls.n_iter = 4
cls.refit = True
cls.n_jobs = 1
cls.scoring = make_scorer(accuracy_score, )
cls.cv = StratifiedKFold(n_splits=n_splits, )
cls.random_state = np.random.RandomState(42)
cls.learning_indices = np.array([0,1,2])
cls.view_indices = None
cls.framework = "monoview"
cls.equivalent_draws = False
cls.X = cls.random_state.randint(0,100, (5,11))
cls.y = cls.random_state.randint(0,1, 5)
def test_simple(self):
hyper_parameter_search.MultiviewCompatibleRandomizedSearchCV(
self.estimator, self.param_distributions, n_iter=self.n_iter,
refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring, cv=self.cv,
random_state=self.random_state,
learning_indices=self.learning_indices, view_indices=self.view_indices,
framework=self.framework,
equivalent_draws=self.equivalent_draws
)
def test_fit(self):
RSCV = hyper_parameter_search.MultiviewCompatibleRandomizedSearchCV(
self.estimator, self.param_distributions, n_iter=self.n_iter,
refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring,
cv=self.cv,
random_state=self.random_state,
learning_indices=self.learning_indices,
view_indices=self.view_indices,
framework=self.framework,
equivalent_draws=self.equivalent_draws
)
RSCV.fit(self.X, self.y, )
tested_param1 = np.ma.masked_array(data=[10,10,100,100],
mask=[False, False, False, False])
np.testing.assert_array_equal(RSCV.cv_results_['param_param1'],
tested_param1)
def test_fit_multiview(self):
RSCV = hyper_parameter_search.MultiviewCompatibleRandomizedSearchCV(
FakeEstimMV(), self.param_distributions, n_iter=self.n_iter,
refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring,
cv=self.cv,
random_state=self.random_state,
learning_indices=self.learning_indices,
view_indices=self.view_indices,
framework="multiview",
equivalent_draws=self.equivalent_draws
)
RSCV.fit(test_dataset, self.y, )
self.assertEqual(RSCV.n_iter, self.n_iter)
def test_fit_multiview_equiv(self):
RSCV = hyper_parameter_search.MultiviewCompatibleRandomizedSearchCV(
FakeEstimMV(), self.param_distributions, n_iter=self.n_iter,
refit=self.refit, n_jobs=self.n_jobs, scoring=self.scoring,
cv=self.cv,
random_state=self.random_state,
learning_indices=self.learning_indices,
view_indices=self.view_indices,
framework="multiview",
equivalent_draws=True
)
RSCV.fit(test_dataset, self.y, )
self.assertEqual(RSCV.n_iter, self.n_iter*test_dataset.nb_view)
# if __name__ == '__main__':
# # unittest.main()
# suite = unittest.TestLoader().loadTestsFromTestCase(Test_randomized_search)
# unittest.TextTestRunner(verbosity=2).run(suite)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment