Skip to content
Snippets Groups Projects
Select Git revision
  • aed430d4d8397d506028ca405df7dd79c43404f6
  • master default protected
2 results

VGG19Cifar10CovAbs.py

Blame
  • sgd.py 1.33 KiB
    from sklearn.linear_model import SGDClassifier
    
    from ..monoview.monoview_utils import BaseMonoviewClassifier
    from summit.multiview_platform.utils.hyper_parameter_search import CustomUniform
    
    # Author-Info
    __author__ = "Baptiste Bauvin"
    __status__ = "Prototype"  # Production, Development, Prototype
    
    classifier_class_name = "SGD"
    
    
    class SGD(SGDClassifier, BaseMonoviewClassifier):
        """
        This class is an adaptation of scikit-learn's `SGDClassifier <https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html>`_
    
    
        """
    
        def __init__(self, random_state=None, loss='hinge',
                     penalty='l2', alpha=0.0001, max_iter=5, tol=None, **kwargs):
            SGDClassifier.__init__(self,
                                   loss=loss,
                                   penalty=penalty,
                                   alpha=alpha,
                                   max_iter=max_iter,
                                   tol=tol,
                                   random_state=random_state
                                   )
            self.param_names = ["loss", "penalty", "alpha", "random_state"]
            self.classed_params = []
            self.distribs = [['log', 'modified_huber'],
                             ["l1", "l2", "elasticnet"],
                             CustomUniform(loc=0, state=1), [random_state]]
            self.weird_strings = {}