Skip to content
Snippets Groups Projects
Commit a2e35c91 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Modded scm_bagging for random search

parent c4bd09cf
Branches
No related tags found
No related merge requests found
...@@ -22,13 +22,13 @@ nb_folds: 5 ...@@ -22,13 +22,13 @@ nb_folds: 5
nb_class: 2 nb_class: 2
classes: classes:
type: ["monoview"] type: ["monoview"]
algos_monoview: ["cb_boost",] algos_monoview: ["scm_bagging",]
algos_multiview: ["group_scm"] algos_multiview: ["group_scm"]
stats_iter: 2 stats_iter: 2
metrics: metrics:
accuracy_score: {} accuracy_score: {}
f1_score: f1_score:
average: 'binary' average: 'binary'
metric_princ: "accuracy_score" metric_princ: "f1_score"
hps_type: "None" hps_type: "Random"
hps_args: {} hps_args: {}
\ No newline at end of file
...@@ -25,6 +25,7 @@ class ImbalanceBagging(BaseMonoviewClassifier, BalancedBaggingClassifier): ...@@ -25,6 +25,7 @@ class ImbalanceBagging(BaseMonoviewClassifier, BalancedBaggingClassifier):
base_boosting_estimators, base_boosting_estimators,
["auto"]] ["auto"]]
self.weird_strings = {"base_estimator": "class_name"} self.weird_strings = {"base_estimator": "class_name"}
self.base_estimator_config = base_estimator_config
...@@ -22,6 +22,7 @@ import numbers ...@@ -22,6 +22,7 @@ import numbers
import numpy as np import numpy as np
from six import iteritems from six import iteritems
from warnings import warn from warnings import warn
import logging
MAX_INT = np.iinfo(np.int32).max MAX_INT = np.iinfo(np.int32).max
...@@ -83,9 +84,11 @@ class ScmBaggingClassifier(BaseEnsemble, ClassifierMixin, BaseMonoviewClassifier ...@@ -83,9 +84,11 @@ class ScmBaggingClassifier(BaseEnsemble, ClassifierMixin, BaseMonoviewClassifier
max_samples=1.0, max_samples=1.0,
max_features=1.0, max_features=1.0,
max_rules=10, max_rules=10,
p_options=[1.0], p_options=[0.316],
model_type="conjunction", model_type="conjunction",
random_state=None): random_state=None):
if isinstance(p_options, float):
p_options = [p_options]
self.n_estimators = n_estimators self.n_estimators = n_estimators
self.max_samples = max_samples self.max_samples = max_samples
self.max_features = max_features self.max_features = max_features
...@@ -95,12 +98,21 @@ class ScmBaggingClassifier(BaseEnsemble, ClassifierMixin, BaseMonoviewClassifier ...@@ -95,12 +98,21 @@ class ScmBaggingClassifier(BaseEnsemble, ClassifierMixin, BaseMonoviewClassifier
self.random_state = random_state self.random_state = random_state
self.labels_to_binary = {} self.labels_to_binary = {}
self.binary_to_labels = {} self.binary_to_labels = {}
self.param_names = ["n_estimators", "max_rules", "max_samples", "max_features", "model_type", "random_state"] self.param_names = ["n_estimators", "max_rules", "max_samples", "max_features", "model_type", "p_options", "random_state"]
self.classed_params = [] self.classed_params = []
self.distribs = [CustomRandint(low=1, high=300), CustomRandint(low=1, high=20), self.distribs = [CustomRandint(low=1, high=300), CustomRandint(low=1, high=20),
CustomUniform(), CustomUniform(), ["conjunction", "disjunction"], [random_state]] CustomUniform(), CustomUniform(), ["conjunction", "disjunction"], CustomUniform(), [random_state]]
self.weird_strings = {} self.weird_strings = {}
def set_params(self, p_options=[0.316], **kwargs):
if not isinstance(p_options, list):
p_options = [p_options]
kwargs["p_options"] = p_options
for parameter, value in iteritems(kwargs):
setattr(self, parameter, value)
return self
def p_for_estimators(self): def p_for_estimators(self):
"""Return the value of p for each estimator to fit.""" """Return the value of p for each estimator to fit."""
options_len = len(self.p_options) # number of options options_len = len(self.p_options) # number of options
...@@ -135,10 +147,10 @@ class ScmBaggingClassifier(BaseEnsemble, ClassifierMixin, BaseMonoviewClassifier ...@@ -135,10 +147,10 @@ class ScmBaggingClassifier(BaseEnsemble, ClassifierMixin, BaseMonoviewClassifier
} }
return hyperparams return hyperparams
def set_params(self, **parameters): # def set_params(self, **parameters):
for parameter, value in iteritems(parameters): # for parameter, value in iteritems(parameters):
setattr(self, parameter, value) # setattr(self, parameter, value)
return self # return self
def labels_conversion(self, labels_list): def labels_conversion(self, labels_list):
l = list(set(labels_list)) l = list(set(labels_list))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment