diff --git a/multimodal/tests/test_combo.py b/multimodal/tests/test_combo.py index c44a8d96aed1cac7a9081ef089df180c20a23a9e..f0588935bcfd2e8f435058ebf5d7ab468a196a4e 100644 --- a/multimodal/tests/test_combo.py +++ b/multimodal/tests/test_combo.py @@ -64,6 +64,12 @@ from multimodal.boosting.combo import MuComboClassifier from multimodal.tests.data.get_dataset_path import get_dataset_path from multimodal.datasets.data_sample import MultiModalArray +class NoSampleWeightLasso(Lasso): + + def fit(self, X, y, check_input=True): + return Lasso.fit(self, X, y, sample_weight=None, check_input=True) + + class TestMuComboClassifier(unittest.TestCase): @classmethod @@ -836,7 +842,7 @@ class TestMuComboClassifier(unittest.TestCase): # def test_classifier(self): - return check_estimator(MuComboClassifier) + return check_estimator(MuComboClassifier()) # # # def test_iris(): @@ -957,7 +963,8 @@ class TestMuComboClassifier(unittest.TestCase): # # Check that using a base estimator that doesn't support sample_weight # # raises an error. - clf = MuComboClassifier(Lasso()) + clf = MuComboClassifier(NoSampleWeightLasso()) + self.assertRaises(ValueError, clf.fit, self.iris.data, self.iris.target, self.iris.views_ind) # assert_raises(ValueError, clf.fit, iris.data, iris.target, iris.views_ind) # diff --git a/multimodal/tests/test_mumbo.py b/multimodal/tests/test_mumbo.py index 38f846c7535c6efa9b5f97fec49a90dcd3c5dec8..7169c8b51c073947bea20b4fede3d4389f16af96 100644 --- a/multimodal/tests/test_mumbo.py +++ b/multimodal/tests/test_mumbo.py @@ -57,6 +57,7 @@ from sklearn.tree import DecisionTreeClassifier from sklearn import datasets from multimodal.boosting.mumbo import MumboClassifier +from multimodal.tests.test_combo import NoSampleWeightLasso class TestMuCumboClassifier(unittest.TestCase): @@ -730,7 +731,7 @@ class TestMuCumboClassifier(unittest.TestCase): # e = MumboClassifier() # e.fit(X_zero_features, y) # print(e.predict(X_zero_features)) - return check_estimator(MumboClassifier) + return check_estimator(MumboClassifier()) def test_iris(self): # Check consistency on dataset iris. @@ -840,7 +841,7 @@ class TestMuCumboClassifier(unittest.TestCase): # Check that using a base estimator that doesn't support sample_weight # raises an error. - clf = MumboClassifier(Lasso()) + clf = MumboClassifier(NoSampleWeightLasso()) self.assertRaises(ValueError, clf.fit, self.iris.data, self.iris.target, self.iris.views_ind) diff --git a/setup.py b/setup.py index 62c350a954359f10c65c747680d4767beb522db1..a4e76c403cf5ce354add295f88c0130c6fbb751a 100644 --- a/setup.py +++ b/setup.py @@ -176,7 +176,7 @@ def setup_package(): keywords = ('machine learning, supervised learning, classification, ' 'ensemble methods, boosting, kernel') packages = find_packages(exclude=['*.tests']) - install_requires = ['scikit-learn>=0.22', 'numpy', 'scipy', 'cvxopt' ] + install_requires = ['scikit-learn>=0.24', 'numpy', 'scipy', 'cvxopt' ] python_requires = '>=3.5' extras_require = { 'dev': ['pytest', 'pytest-cov'],