Skip to content
Snippets Groups Projects
Commit e3e810cd authored by Dominique Benielli's avatar Dominique Benielli
Browse files

fix doc test bug

parent c6616969
No related branches found
No related tags found
No related merge requests found
Pipeline #3981 failed
......@@ -361,7 +361,10 @@ class MultiModalArray(np.ndarray, MultiModalData):
try:
new_data = np.asarray(data)
if views_ind is None:
views_ind = np.array([0, new_data.shape[1]])
if new_data.shape[1] > 1:
views_ind = np.array([0, new_data.shape[1] // 2, new_data.shape[1]])
else:
views_ind = np.array([0, new_data.shape[1]])
except Exception as e:
raise ValueError('Reshape your data')
if new_data.ndim < 2 :
......
......@@ -85,7 +85,7 @@ class MKernel(metaclass=ABCMeta):
if not isinstance(X_, MultiModalArray):
try:
X_ = np.asarray(X)
X_ = MultiModalArray(X_)
X_ = MultiModalArray(X_, views_ind)
except Exception as e:
pass
# raise TypeError('Reshape your data')
......
......@@ -104,36 +104,17 @@ class MVML(MKernel, BaseEstimator, ClassifierMixin):
>>> from multimodal.kernels.mvml import MVML
>>> from sklearn.datasets import load_iris
>>> X, y = load_iris(return_X_y=True)
>>> y[y>0] = 1
>>> views_ind = [0, 2, 4] # view 0: sepal data, view 1: petal data
>>> clf = MVML()
clf.get_params()
>>> clf.get_params()
{'eta': 1, 'kernel': 'linear', 'kernel_params': None, 'learn_A': 1, 'learn_w': 0, 'lmbda': 0.1, 'n_loops': 6, 'nystrom_param': 1.0, 'precision': 0.0001}
>>> clf.fit(X, y, views_ind) # doctest: +NORMALIZE_WHITESPACE
MumboClassifier(base_estimator=None, best_view_mode='edge',
n_estimators=50, random_state=0)
MVML(eta=1, kernel='linear', kernel_params=None, learn_A=1, learn_w=0,
lmbda=0.1, n_loops=6, nystrom_param=1.0, precision=0.0001)
>>> print(clf.predict([[ 5., 3., 1., 1.]]))
[1]
>>> views_ind = [[0, 2], [1, 3]] # view 0: length data, view 1: width data
>>> clf = MumboClassifier(random_state=0)
>>> clf.fit(X, y, views_ind) # doctest: +NORMALIZE_WHITESPACE
MumboClassifier(base_estimator=None, best_view_mode='edge',
n_estimators=50, random_state=0)
>>> print(clf.predict([[ 5., 3., 1., 1.]]))
[1]
0
>>> from sklearn.tree import DecisionTreeClassifier
>>> base_estimator = DecisionTreeClassifier(max_depth=2)
>>> clf = MumboClassifier(base_estimator=base_estimator, random_state=0)
>>> clf.fit(X, y, views_ind) # doctest: +NORMALIZE_WHITESPACE
MumboClassifier(base_estimator=DecisionTreeClassifier(class_weight=None,
criterion='gini', max_depth=2, max_features=None,
max_leaf_nodes=None, min_impurity_decrease=0.0,
min_impurity_split=None, min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, presort=False, random_state=None,
splitter='best'),
best_view_mode='edge', n_estimators=50, random_state=0)
>>> print(clf.predict([[ 5., 3., 1., 1.]]))
[1]
"""
# r_cond = 10-30
def __init__(self, lmbda=0.1, eta=1, nystrom_param=1.0, kernel="linear",
......@@ -471,8 +452,8 @@ class MVML(MKernel, BaseEstimator, ClassifierMixin):
return pred
else:
pred = np.sign(pred)
pred[pred==-1] = 0
pred = pred.astype(int)
pred = np.where(pred == -1, 0 , pred)
return np.take(self.classes_, pred)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment