Skip to content
Snippets Groups Projects
Commit a2e35c91 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Modded scm_bagging for random search

parent c4bd09cf
No related branches found
No related tags found
No related merge requests found
......@@ -22,13 +22,13 @@ nb_folds: 5
nb_class: 2
classes:
type: ["monoview"]
algos_monoview: ["cb_boost",]
algos_monoview: ["scm_bagging",]
algos_multiview: ["group_scm"]
stats_iter: 2
metrics:
accuracy_score: {}
f1_score:
average: 'binary'
metric_princ: "accuracy_score"
hps_type: "None"
metric_princ: "f1_score"
hps_type: "Random"
hps_args: {}
\ No newline at end of file
......@@ -25,6 +25,7 @@ class ImbalanceBagging(BaseMonoviewClassifier, BalancedBaggingClassifier):
base_boosting_estimators,
["auto"]]
self.weird_strings = {"base_estimator": "class_name"}
self.base_estimator_config = base_estimator_config
......@@ -22,6 +22,7 @@ import numbers
import numpy as np
from six import iteritems
from warnings import warn
import logging
MAX_INT = np.iinfo(np.int32).max
......@@ -83,9 +84,11 @@ class ScmBaggingClassifier(BaseEnsemble, ClassifierMixin, BaseMonoviewClassifier
max_samples=1.0,
max_features=1.0,
max_rules=10,
p_options=[1.0],
p_options=[0.316],
model_type="conjunction",
random_state=None):
if isinstance(p_options, float):
p_options = [p_options]
self.n_estimators = n_estimators
self.max_samples = max_samples
self.max_features = max_features
......@@ -95,12 +98,21 @@ class ScmBaggingClassifier(BaseEnsemble, ClassifierMixin, BaseMonoviewClassifier
self.random_state = random_state
self.labels_to_binary = {}
self.binary_to_labels = {}
self.param_names = ["n_estimators", "max_rules", "max_samples", "max_features", "model_type", "random_state"]
self.param_names = ["n_estimators", "max_rules", "max_samples", "max_features", "model_type", "p_options", "random_state"]
self.classed_params = []
self.distribs = [CustomRandint(low=1, high=300), CustomRandint(low=1, high=20),
CustomUniform(), CustomUniform(), ["conjunction", "disjunction"], [random_state]]
CustomUniform(), CustomUniform(), ["conjunction", "disjunction"], CustomUniform(), [random_state]]
self.weird_strings = {}
def set_params(self, p_options=[0.316], **kwargs):
if not isinstance(p_options, list):
p_options = [p_options]
kwargs["p_options"] = p_options
for parameter, value in iteritems(kwargs):
setattr(self, parameter, value)
return self
def p_for_estimators(self):
"""Return the value of p for each estimator to fit."""
options_len = len(self.p_options) # number of options
......@@ -135,10 +147,10 @@ class ScmBaggingClassifier(BaseEnsemble, ClassifierMixin, BaseMonoviewClassifier
}
return hyperparams
def set_params(self, **parameters):
for parameter, value in iteritems(parameters):
setattr(self, parameter, value)
return self
# def set_params(self, **parameters):
# for parameter, value in iteritems(parameters):
# setattr(self, parameter, value)
# return self
def labels_conversion(self, labels_list):
l = list(set(labels_list))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment