Commit 299a5389 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

compatible with sklearn 00.24

parent 89d4a781
Pipeline #6438 failed with stage
in 6 minutes and 8 seconds
......@@ -64,6 +64,12 @@ from multimodal.boosting.combo import MuComboClassifier
from multimodal.tests.data.get_dataset_path import get_dataset_path
from multimodal.datasets.data_sample import MultiModalArray
class NoSampleWeightLasso(Lasso):
def fit(self, X, y, check_input=True):
return Lasso.fit(self, X, y, sample_weight=None, check_input=True)
class TestMuComboClassifier(unittest.TestCase):
@classmethod
......@@ -836,7 +842,7 @@ class TestMuComboClassifier(unittest.TestCase):
#
def test_classifier(self):
return check_estimator(MuComboClassifier)
return check_estimator(MuComboClassifier())
#
#
# def test_iris():
......@@ -957,7 +963,8 @@ class TestMuComboClassifier(unittest.TestCase):
# # Check that using a base estimator that doesn't support sample_weight
# # raises an error.
clf = MuComboClassifier(Lasso())
clf = MuComboClassifier(NoSampleWeightLasso())
self.assertRaises(ValueError, clf.fit, self.iris.data, self.iris.target, self.iris.views_ind)
# assert_raises(ValueError, clf.fit, iris.data, iris.target, iris.views_ind)
#
......
......@@ -57,6 +57,7 @@ from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets
from multimodal.boosting.mumbo import MumboClassifier
from multimodal.tests.test_combo import NoSampleWeightLasso
class TestMuCumboClassifier(unittest.TestCase):
......@@ -730,7 +731,7 @@ class TestMuCumboClassifier(unittest.TestCase):
# e = MumboClassifier()
# e.fit(X_zero_features, y)
# print(e.predict(X_zero_features))
return check_estimator(MumboClassifier)
return check_estimator(MumboClassifier())
def test_iris(self):
# Check consistency on dataset iris.
......@@ -840,7 +841,7 @@ class TestMuCumboClassifier(unittest.TestCase):
# Check that using a base estimator that doesn't support sample_weight
# raises an error.
clf = MumboClassifier(Lasso())
clf = MumboClassifier(NoSampleWeightLasso())
self.assertRaises(ValueError, clf.fit, self.iris.data, self.iris.target, self.iris.views_ind)
......
......@@ -176,7 +176,7 @@ def setup_package():
keywords = ('machine learning, supervised learning, classification, '
'ensemble methods, boosting, kernel')
packages = find_packages(exclude=['*.tests'])
install_requires = ['scikit-learn>=0.22', 'numpy', 'scipy', 'cvxopt' ]
install_requires = ['scikit-learn>=0.24', 'numpy', 'scipy', 'cvxopt' ]
python_requires = '>=3.5'
extras_require = {
'dev': ['pytest', 'pytest-cov'],
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment