Skip to content
Snippets Groups Projects
Commit 299a5389 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

compatible with sklearn 00.24

parent 89d4a781
No related branches found
No related tags found
No related merge requests found
Pipeline #6438 failed
...@@ -64,6 +64,12 @@ from multimodal.boosting.combo import MuComboClassifier ...@@ -64,6 +64,12 @@ from multimodal.boosting.combo import MuComboClassifier
from multimodal.tests.data.get_dataset_path import get_dataset_path from multimodal.tests.data.get_dataset_path import get_dataset_path
from multimodal.datasets.data_sample import MultiModalArray from multimodal.datasets.data_sample import MultiModalArray
class NoSampleWeightLasso(Lasso):
def fit(self, X, y, check_input=True):
return Lasso.fit(self, X, y, sample_weight=None, check_input=True)
class TestMuComboClassifier(unittest.TestCase): class TestMuComboClassifier(unittest.TestCase):
@classmethod @classmethod
...@@ -836,7 +842,7 @@ class TestMuComboClassifier(unittest.TestCase): ...@@ -836,7 +842,7 @@ class TestMuComboClassifier(unittest.TestCase):
# #
def test_classifier(self): def test_classifier(self):
return check_estimator(MuComboClassifier) return check_estimator(MuComboClassifier())
# #
# #
# def test_iris(): # def test_iris():
...@@ -957,7 +963,8 @@ class TestMuComboClassifier(unittest.TestCase): ...@@ -957,7 +963,8 @@ class TestMuComboClassifier(unittest.TestCase):
# # Check that using a base estimator that doesn't support sample_weight # # Check that using a base estimator that doesn't support sample_weight
# # raises an error. # # raises an error.
clf = MuComboClassifier(Lasso()) clf = MuComboClassifier(NoSampleWeightLasso())
self.assertRaises(ValueError, clf.fit, self.iris.data, self.iris.target, self.iris.views_ind) self.assertRaises(ValueError, clf.fit, self.iris.data, self.iris.target, self.iris.views_ind)
# assert_raises(ValueError, clf.fit, iris.data, iris.target, iris.views_ind) # assert_raises(ValueError, clf.fit, iris.data, iris.target, iris.views_ind)
# #
......
...@@ -57,6 +57,7 @@ from sklearn.tree import DecisionTreeClassifier ...@@ -57,6 +57,7 @@ from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets from sklearn import datasets
from multimodal.boosting.mumbo import MumboClassifier from multimodal.boosting.mumbo import MumboClassifier
from multimodal.tests.test_combo import NoSampleWeightLasso
class TestMuCumboClassifier(unittest.TestCase): class TestMuCumboClassifier(unittest.TestCase):
...@@ -730,7 +731,7 @@ class TestMuCumboClassifier(unittest.TestCase): ...@@ -730,7 +731,7 @@ class TestMuCumboClassifier(unittest.TestCase):
# e = MumboClassifier() # e = MumboClassifier()
# e.fit(X_zero_features, y) # e.fit(X_zero_features, y)
# print(e.predict(X_zero_features)) # print(e.predict(X_zero_features))
return check_estimator(MumboClassifier) return check_estimator(MumboClassifier())
def test_iris(self): def test_iris(self):
# Check consistency on dataset iris. # Check consistency on dataset iris.
...@@ -840,7 +841,7 @@ class TestMuCumboClassifier(unittest.TestCase): ...@@ -840,7 +841,7 @@ class TestMuCumboClassifier(unittest.TestCase):
# Check that using a base estimator that doesn't support sample_weight # Check that using a base estimator that doesn't support sample_weight
# raises an error. # raises an error.
clf = MumboClassifier(Lasso()) clf = MumboClassifier(NoSampleWeightLasso())
self.assertRaises(ValueError, clf.fit, self.iris.data, self.iris.target, self.iris.views_ind) self.assertRaises(ValueError, clf.fit, self.iris.data, self.iris.target, self.iris.views_ind)
......
...@@ -176,7 +176,7 @@ def setup_package(): ...@@ -176,7 +176,7 @@ def setup_package():
keywords = ('machine learning, supervised learning, classification, ' keywords = ('machine learning, supervised learning, classification, '
'ensemble methods, boosting, kernel') 'ensemble methods, boosting, kernel')
packages = find_packages(exclude=['*.tests']) packages = find_packages(exclude=['*.tests'])
install_requires = ['scikit-learn>=0.22', 'numpy', 'scipy', 'cvxopt' ] install_requires = ['scikit-learn>=0.24', 'numpy', 'scipy', 'cvxopt' ]
python_requires = '>=3.5' python_requires = '>=3.5'
extras_require = { extras_require = {
'dev': ['pytest', 'pytest-cov'], 'dev': ['pytest', 'pytest-cov'],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment