Skip to content
Snippets Groups Projects
Commit 3d787502 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Added fakeestimator

parent 9f5287e7
No related branches found
No related tags found
No related merge requests found
Pipeline #3925 failed
from metriclearning.lpMKL import MKL
from ..multiview.multiview_utils import BaseMultiviewClassifier, get_examples_views_indices
from ..multiview.multiview_utils import BaseMultiviewClassifier, FakeEstimator
from .additions.kernel_learning import KernelClassifier, KernelConfigGenerator, KernelGenerator
from ..utils.hyper_parameter_search import CustomUniform, CustomRandint
......@@ -29,7 +29,7 @@ class LPNormMKL(KernelClassifier, MKL):
try:
self.init_kernels(nb_view=len(formatted_X))
except:
return FakeClassifier()
return FakeEstimator()
return super(LPNormMKL, self).fit(formatted_X, y[train_indices])
......
from metriclearning.mvml import MVML
from ..multiview.multiview_utils import BaseMultiviewClassifier, get_examples_views_indices
from ..multiview.multiview_utils import BaseMultiviewClassifier, FakeEstimator
from .additions.kernel_learning import KernelClassifier, KernelConfigGenerator, KernelGenerator
from ..utils.hyper_parameter_search import CustomUniform, CustomRandint
......@@ -36,7 +36,10 @@ class MVMLClassifier(KernelClassifier, MVML):
def fit(self, X, y, train_indices=None, view_indices=None):
formatted_X, train_indices = self.format_X(X, train_indices, view_indices)
try:
self.init_kernels(nb_view=len(formatted_X))
except:
return FakeEstimator()
return super(MVMLClassifier, self).fit(formatted_X, y[train_indices])
def predict(self, X, example_indices=None, view_indices=None):
......
......@@ -197,7 +197,7 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV):
for fold_idx, (train_indices, test_indices) in enumerate(folds):
current_estimator = clone(base_estimator)
current_estimator.set_params(**candidate_param)
current_estimator.fit(X, y,
current_estimator = current_estimator.fit(X, y,
train_indices=self.available_indices[train_indices],
view_indices=self.view_indices)
test_prediction = current_estimator.predict(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment