Skip to content
Snippets Groups Projects
Commit 7baef99d authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

tried to do the right hp search for earlyfusion

parent 5862a1cd
Branches
No related tags found
No related merge requests found
......@@ -2,6 +2,7 @@ from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np
from .. import multiview_classifiers
from .. import monoview_classifiers
......@@ -92,3 +93,29 @@ def get_train_views_indices(dataset, train_indices, view_indices,):
if train_indices is None:
train_indices = range(dataset["Metadata"].attrs["datasetLength"])
return train_indices, view_indices
class ConfigGenerator():
def __init__(self):
self.distribs = {}
for name in dir(monoview_classifiers):
if not name.startswith("__"):
module = getattr(monoview_classifiers, name)
classifier_class = getattr(module,
module.classifier_class_name)()
self.distribs[name] = dict((param_name, param_distrib)
for param_name, param_distrib in
zip(classifier_class().param_names,
classifier_class().distribs))
def rvs(self, random_state=None):
config_sample = {}
for classifier_name, classifier_config in self.distribs.items():
for param_name, param_distrib in classifier_config.items():
if hasattr(param_distrib, "rvs"):
config_sample[classifier_name][param_name]=param_distrib.rvs(random_state=random_state)
else:
config_sample[classifier_name][
param_name] = param_distrib[random_state.randint(len(param_distrib))]
return config_sample
......@@ -2,7 +2,7 @@ import numpy as np
import pkgutil
from ..utils.dataset import getV
from ..multiview.multiview_utils import BaseMultiviewClassifier, get_train_views_indices
from ..multiview.multiview_utils import BaseMultiviewClassifier, get_train_views_indices, ConfigGenerator
from .. import monoview_classifiers
classifier_class_name = "WeightedLinearEarlyFusion"
......@@ -24,19 +24,27 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier):
self.monoview_classifier = monoview_classifier_class(random_state=random_state,
**monoview_classifier_config)
else:
self.monoview_classifier = monoview_classifier
self.monoview_classifier = monoview_classifier(monoview_classifier_config)
self.short_name = "early fusion "+self.monoview_classifier.__class__.__name__
self.param_names = ["monoview_classifier","random_state"]
self.param_names = ["monoview_classifier","random_state", "monoview_classifier_config"]
classifier_classes = []
for name in dir(monoview_classifiers):
if not name.startswith("__"):
module = getattr(monoview_classifiers, name)
classifier_class = getattr(module, module.classifier_class_name)()
classifier_class = getattr(module, module.classifier_class_name)
classifier_classes.append(classifier_class)
self.distribs = [classifier_classes, [self.random_state]]
self.distribs = [classifier_classes, [self.random_state], ConfigGenerator()]
self.classed_params = ["monoview_classifier"]
self.weird_strings={"monoview_classifier":["class_name", "config"]}
def set_params(self, monoview_classifier=None, monoview_classifier_config=None, **params):
monoview_classifier_name = monoview_classifier.__module__
self.monoview_classifier = monoview_classifier()
self.set_monoview_classifier_config(monoview_classifier_name,
monoview_classifier_config)
def fit(self, X, y, train_indices=None, view_indices=None):
train_indices, X = self.transform_data_to_monoview(X, train_indices, view_indices)
self.monoview_classifier.fit(X, y[train_indices])
......@@ -70,6 +78,12 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier):
, axis=1)
return monoview_data
def set_monoview_classifier_config(self, monoview_classifier_name, monoview_classifier_config):
if monoview_classifier_name in monoview_classifier_config:
self.monoview_classifier.set_params(monoview_classifier_config[monoview_classifier_name])
else:
self.monoview_classifier.set_params(monoview_classifier_config)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment