From a2e35c91c0e95cfec68ff19be19668258d1679c1 Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Thu, 21 May 2020 08:10:49 -0400
Subject: [PATCH] Modded scm_bagging for random search

---
 config_files/config_cuisine.yml               |  6 ++---
 .../monoview_classifiers/imbalance_bagging.py |  1 +
 .../monoview_classifiers/scm_bagging.py       | 26 ++++++++++++++-----
 3 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/config_files/config_cuisine.yml b/config_files/config_cuisine.yml
index f4ed64cd..656d2d87 100644
--- a/config_files/config_cuisine.yml
+++ b/config_files/config_cuisine.yml
@@ -22,13 +22,13 @@ nb_folds: 5
 nb_class: 2
 classes:
 type: ["monoview"]
-algos_monoview: ["cb_boost",]
+algos_monoview: ["scm_bagging",]
 algos_multiview: ["group_scm"]
 stats_iter: 2
 metrics:
   accuracy_score: {}
   f1_score:
     average: 'binary'
-metric_princ: "accuracy_score"
-hps_type: "None"
+metric_princ: "f1_score"
+hps_type: "Random"
 hps_args: {}
\ No newline at end of file
diff --git a/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py b/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py
index 4037e6ee..c4340420 100644
--- a/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py
+++ b/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py
@@ -25,6 +25,7 @@ class ImbalanceBagging(BaseMonoviewClassifier, BalancedBaggingClassifier):
                          base_boosting_estimators,
                          ["auto"]]
         self.weird_strings = {"base_estimator": "class_name"}
+        self.base_estimator_config = base_estimator_config
 
 
 
diff --git a/summit/multiview_platform/monoview_classifiers/scm_bagging.py b/summit/multiview_platform/monoview_classifiers/scm_bagging.py
index 2a259a4d..d91598cf 100644
--- a/summit/multiview_platform/monoview_classifiers/scm_bagging.py
+++ b/summit/multiview_platform/monoview_classifiers/scm_bagging.py
@@ -22,6 +22,7 @@ import numbers
 import numpy as np
 from six import iteritems
 from warnings import warn
+import logging
 
 MAX_INT = np.iinfo(np.int32).max
 
@@ -83,9 +84,11 @@ class ScmBaggingClassifier(BaseEnsemble, ClassifierMixin, BaseMonoviewClassifier
                  max_samples=1.0,
                  max_features=1.0,
                  max_rules=10,
-                 p_options=[1.0],
+                 p_options=[0.316],
                  model_type="conjunction",
                  random_state=None):
+        if isinstance(p_options, float):
+            p_options = [p_options]
         self.n_estimators = n_estimators
         self.max_samples = max_samples
         self.max_features = max_features
@@ -95,12 +98,21 @@ class ScmBaggingClassifier(BaseEnsemble, ClassifierMixin, BaseMonoviewClassifier
         self.random_state = random_state
         self.labels_to_binary = {}
         self.binary_to_labels = {}
-        self.param_names = ["n_estimators", "max_rules", "max_samples", "max_features", "model_type", "random_state"]
+        self.param_names = ["n_estimators", "max_rules", "max_samples", "max_features", "model_type", "p_options", "random_state"]
         self.classed_params = []
         self.distribs = [CustomRandint(low=1, high=300), CustomRandint(low=1, high=20),
-                         CustomUniform(), CustomUniform(), ["conjunction", "disjunction"], [random_state]]
+                         CustomUniform(), CustomUniform(), ["conjunction", "disjunction"], CustomUniform(), [random_state]]
         self.weird_strings = {}
 
+    def set_params(self, p_options=[0.316], **kwargs):
+        if not isinstance(p_options, list):
+            p_options = [p_options]
+        kwargs["p_options"] = p_options
+        for parameter, value in iteritems(kwargs):
+            setattr(self, parameter, value)
+        return self
+
+
     def p_for_estimators(self):
         """Return the value of p for each estimator to fit."""
         options_len = len(self.p_options)  # number of options
@@ -135,10 +147,10 @@ class ScmBaggingClassifier(BaseEnsemble, ClassifierMixin, BaseMonoviewClassifier
         }
         return hyperparams
 
-    def set_params(self, **parameters):
-        for parameter, value in iteritems(parameters):
-            setattr(self, parameter, value)
-        return self
+    # def set_params(self, **parameters):
+    #     for parameter, value in iteritems(parameters):
+    #         setattr(self, parameter, value)
+    #     return self
 
     def labels_conversion(self, labels_list):
         l = list(set(labels_list))
-- 
GitLab