diff --git a/config_files/config_cuisine.yml b/config_files/config_cuisine.yml
index f4ed64cd8621d1fb98ee0e49b57051f007b163e9..656d2d87e2f506e1687a057dd68a25bc3a198f9c 100644
--- a/config_files/config_cuisine.yml
+++ b/config_files/config_cuisine.yml
@@ -22,13 +22,13 @@ nb_folds: 5
 nb_class: 2
 classes:
 type: ["monoview"]
-algos_monoview: ["cb_boost",]
+algos_monoview: ["scm_bagging",]
 algos_multiview: ["group_scm"]
 stats_iter: 2
 metrics:
   accuracy_score: {}
   f1_score:
     average: 'binary'
-metric_princ: "accuracy_score"
-hps_type: "None"
+metric_princ: "f1_score"
+hps_type: "Random"
 hps_args: {}
\ No newline at end of file
diff --git a/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py b/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py
index 4037e6ee5c5ac39f40662e0b083f40d029deb172..c434042044e96cd9cb5947b70aee3c4ace77647f 100644
--- a/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py
+++ b/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py
@@ -25,6 +25,7 @@ class ImbalanceBagging(BaseMonoviewClassifier, BalancedBaggingClassifier):
                          base_boosting_estimators,
                          ["auto"]]
         self.weird_strings = {"base_estimator": "class_name"}
+        self.base_estimator_config = base_estimator_config
diff --git a/summit/multiview_platform/monoview_classifiers/scm_bagging.py b/summit/multiview_platform/monoview_classifiers/scm_bagging.py
index 2a259a4d6f01d3ab97979d3ab59c03ab5dc81f86..d91598cf63956de55e85a6f68deb9e2f07d6059c 100644
--- a/summit/multiview_platform/monoview_classifiers/scm_bagging.py
+++ b/summit/multiview_platform/monoview_classifiers/scm_bagging.py
@@ -22,6 +22,7 @@ import numbers
 import numpy as np
 from six import iteritems
 from warnings import warn
+import logging
 
 MAX_INT = np.iinfo(np.int32).max
 
@@ -83,9 +84,11 @@ class ScmBaggingClassifier(BaseEnsemble, ClassifierMixin, BaseMonoviewClassifier
                  max_samples=1.0,
                  max_features=1.0,
                  max_rules=10,
-                 p_options=[1.0],
+                 p_options=[0.316],
                  model_type="conjunction",
                  random_state=None):
+        if isinstance(p_options, float):
+            p_options = [p_options]
         self.n_estimators = n_estimators
         self.max_samples = max_samples
         self.max_features = max_features
@@ -95,12 +98,21 @@ class ScmBaggingClassifier(BaseEnsemble, ClassifierMixin, BaseMonoviewClassifier
         self.random_state = random_state
         self.labels_to_binary = {}
         self.binary_to_labels = {}
-        self.param_names = ["n_estimators", "max_rules", "max_samples", "max_features", "model_type", "random_state"]
+        self.param_names = ["n_estimators", "max_rules", "max_samples", "max_features", "model_type", "p_options", "random_state"]
         self.classed_params = []
         self.distribs = [CustomRandint(low=1, high=300), CustomRandint(low=1, high=20),
-                         CustomUniform(), CustomUniform(), ["conjunction", "disjunction"], [random_state]]
+                         CustomUniform(), CustomUniform(), ["conjunction", "disjunction"], CustomUniform(), [random_state]]
         self.weird_strings = {}
 
+    def set_params(self, p_options=[0.316], **kwargs):
+        if not isinstance(p_options, list):
+            p_options = [p_options]
+        kwargs["p_options"] = p_options
+        for parameter, value in iteritems(kwargs):
+            setattr(self, parameter, value)
+        return self
+
+
     def p_for_estimators(self):
         """Return the value of p for each estimator to fit."""
         options_len = len(self.p_options)  # number of options
@@ -135,10 +147,10 @@ class ScmBaggingClassifier(BaseEnsemble, ClassifierMixin, BaseMonoviewClassifier
         }
         return hyperparams
 
-    def set_params(self, **parameters):
-        for parameter, value in iteritems(parameters):
-            setattr(self, parameter, value)
-        return self
+    # def set_params(self, **parameters):
+    #     for parameter, value in iteritems(parameters):
+    #         setattr(self, parameter, value)
+    #     return self
 
     def labels_conversion(self, labels_list):
         l = list(set(labels_list))
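
Context for the set_params override above: with hps_type: "Random" in the config, the hyper-parameter search samples p_options from the newly added CustomUniform() entry in distribs, which presumably hands the classifier a single float (hence the isinstance(p_options, float) guard in __init__), while p_for_estimators treats self.p_options as a list (options_len = len(self.p_options)). Both the __init__ guard and the overridden set_params therefore wrap a scalar into a one-element list. The sketch below is a minimal, hypothetical stand-in (POptionsDemo is not part of summit) that illustrates only that coercion pattern.

# Hypothetical, standalone sketch: mirrors the scalar-to-list coercion that the
# overridden ScmBaggingClassifier.set_params performs on p_options, without
# importing summit.
class POptionsDemo:
    def __init__(self, p_options=(0.316,)):
        # Store p_options as a list so per-estimator code can index into it.
        self.p_options = list(p_options)

    def set_params(self, **params):
        # A random hyper-parameter search may pass a bare float for p_options;
        # wrap it so code expecting a list (e.g. p_for_estimators) still works.
        if "p_options" in params and not isinstance(params["p_options"], list):
            params["p_options"] = [params["p_options"]]
        for name, value in params.items():
            setattr(self, name, value)
        return self

demo = POptionsDemo().set_params(p_options=0.75)  # scalar, as a sampler might draw
print(demo.p_options)  # -> [0.75]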