Commit 74a5ac85 authored by Baptiste Bauvin

Merge branch 'master' of https://github.com/multi-learn/summit

parents 6fdc4ddb e133cc5f
from imblearn.ensemble import BalancedBaggingClassifier
from sklearn.tree import DecisionTreeClassifier

from ..monoview.monoview_utils import BaseMonoviewClassifier
from ..utils.base import base_boosting_estimators
from ..utils.hyper_parameter_search import CustomRandint, CustomUniform

classifier_class_name = "ImbalanceBagging"


class ImbalanceBagging(BaseMonoviewClassifier, BalancedBaggingClassifier):

    def __init__(self, random_state=None,
                 base_estimator="DecisionTreeClassifier",
                 n_estimators=10, sampling_strategy="auto",
                 replacement=False, base_estimator_config=None):
        base_estimator = self.get_base_estimator(base_estimator,
                                                 base_estimator_config)
        super(ImbalanceBagging, self).__init__(
            random_state=random_state, base_estimator=base_estimator,
            n_estimators=n_estimators,
            sampling_strategy=sampling_strategy,
            replacement=replacement)

        self.param_names = ["n_estimators", "base_estimator",
                            "sampling_strategy"]
        self.classed_params = ["base_estimator"]
        self.distribs = [CustomRandint(low=1, high=50),
                         base_boosting_estimators,
                         ["auto"]]
        self.weird_strings = {"base_estimator": "class_name"}
        self.base_estimator_config = base_estimator_config
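
# Hedged usage sketch (added for illustration, not part of the committed file):
# a minimal scikit-learn-style fit/predict run of ImbalanceBagging on synthetic,
# imbalanced data. It assumes the class is importable as defined above and that
# fit/predict resolve to BalancedBaggingClassifier through the MRO; inside
# SUMMIT the platform normally instantiates and fits the classifier itself.
import numpy as np

X = np.random.RandomState(42).rand(100, 5)
y = np.array([0] * 90 + [1] * 10)  # 90/10 class imbalance

clf = ImbalanceBagging(random_state=42,
                       base_estimator="DecisionTreeClassifier",
                       n_estimators=10)
clf.fit(X, y)
print(clf.predict(X[:5]))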
import numpy as np
from six import iteritems

from randomscm.randomscm import RandomScmClassifier

from ..monoview.monoview_utils import BaseMonoviewClassifier
from summit.multiview_platform.utils.hyper_parameter_search import CustomUniform, CustomRandint

# Author-Info
__author__ = "Baptiste Bauvin"
__status__ = "Prototype"  # Production, Development, Prototype

classifier_class_name = "ScmBagging"

MAX_INT = np.iinfo(np.int32).max
class ScmBagging(RandomScmClassifier, BaseMonoviewClassifier):
    """A bagging classifier for SetCoveringMachineClassifier().

    The base estimators are built on subsets of both the samples
    and the features.

    Parameters
    ----------
    n_estimators : int, default=10
        The number of base estimators in the ensemble.

    max_samples : int or float, default=1.0
        The number of samples to draw from X to train each base estimator,
        with replacement.
        - If int, then draw `max_samples` samples.
        - If float, then draw `max_samples * X.shape[0]` samples.

    max_features : int or float, default=1.0
        The number of features to draw from X to train each base estimator,
        without replacement.
        - If int, then draw `max_features` features.
        - If float, then draw `max_features * X.shape[1]` features.

    p_options : list of float with len <= n_estimators, default=[1.0]
        The estimators are fitted with the values of p found in p_options.
        With k = n_estimators / len(p_options), the first k estimators are
        fitted with p=p_options[0], the next k with p=p_options[1], and so on.
        For example, with n_estimators=4 and p_options=[0.5, 2.0], the first
        two estimators use p=0.5 and the last two use p=2.0.

    random_state : int or RandomState, default=None
        Controls the random resampling of the original dataset
        (sample wise and feature wise).
        If the base estimator accepts a `random_state` attribute, a different
        seed is generated for each instance in the ensemble.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Attributes
    ----------
    n_features_ : int
        The number of features when :meth:`fit` is performed.

    estimators_ : list of estimators
        The collection of fitted base estimators.

    estim_features : list of arrays
        The subset of drawn features for each base estimator.

    Examples
    --------
    >>> @TODO

    References
    ----------
    .. [1] L. Breiman, "Pasting small votes for classification in large
       databases and on-line", Machine Learning, 36(1), 85-103, 1999.
    .. [2] G. Louppe and P. Geurts, "Ensembles on Random Patches", Machine
       Learning and Knowledge Discovery in Databases, 346-361, 2012.
    """
    def __init__(self,
                 n_estimators=100,
                 max_samples=0.5,
                 max_features=0.5,
                 max_rules=10,
                 p_options=[1.0],
                 model_type="conjunction",
                 random_state=None):
        if isinstance(p_options, float):
            p_options = [p_options]
        RandomScmClassifier.__init__(self, n_estimators=n_estimators,
                                     max_samples=max_samples,
                                     max_features=max_features,
                                     max_rules=max_rules,
                                     p_options=p_options,
                                     model_type=model_type,
                                     random_state=random_state)
        self.param_names = ["n_estimators", "max_rules", "max_samples",
                            "max_features", "model_type", "p_options",
                            "random_state"]
        self.classed_params = []
        self.distribs = [CustomRandint(low=1, high=300),
                         CustomRandint(low=1, high=20),
                         CustomUniform(), CustomUniform(),
                         ["conjunction", "disjunction"],
                         CustomUniform(), [random_state]]
        self.weird_strings = {}

    def set_params(self, p_options=[0.316], **kwargs):
        if not isinstance(p_options, list):
            p_options = [p_options]
        kwargs["p_options"] = p_options
        for parameter, value in iteritems(kwargs):
            setattr(self, parameter, value)
        return self

    def get_interpretation(self, directory, base_file_name, y_test,
                           feature_ids, multi_class=False):
        self.features_importance()
        interpret_string = self.get_feature_importance(directory,
                                                       base_file_name,
                                                       feature_ids)
        return interpret_string
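
# Hedged usage sketch (added for illustration, not part of the committed file):
# fitting ScmBagging directly with explicit hyper-parameters on synthetic binary
# data. It assumes randomscm's RandomScmClassifier exposes the usual
# scikit-learn fit/predict interface; inside SUMMIT the hyper-parameters are
# normally drawn from the distributions declared in self.distribs.
import numpy as np

X = np.random.RandomState(0).randint(0, 2, size=(60, 8)).astype(float)
y = np.array([0] * 30 + [1] * 30)

scm_bag = ScmBagging(n_estimators=20,
                     max_samples=0.5,
                     max_features=0.5,
                     max_rules=5,
                     # first half of the estimators use p=0.5, second half p=2.0
                     p_options=[0.5, 2.0],
                     model_type="conjunction",
                     random_state=0)
scm_bag.fit(X, y)
print(scm_bag.predict(X[:5]))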