From 279c4c8ecea20ec450f77d5419604c87f68e278c Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Mon, 6 Mar 2023 09:45:10 -0500
Subject: [PATCH] Pyscm boostable

---
 .../monoview_classifiers/scm.py               | 41 ++++++------
 .../monoview_classifiers/scmboost.py          | 65 +++----------------
 .../multiview_platform/utils/compression.py   |  3 +-
 3 files changed, 30 insertions(+), 79 deletions(-)

diff --git a/summit/multiview_platform/monoview_classifiers/scm.py b/summit/multiview_platform/monoview_classifiers/scm.py
index b3096059..8db54053 100644
--- a/summit/multiview_platform/monoview_classifiers/scm.py
+++ b/summit/multiview_platform/monoview_classifiers/scm.py
@@ -64,13 +64,10 @@ class SCM(scm, BaseMonoviewClassifier):
         self.weird_strings = {}
 
     def fit(self, X, y, tiebreaker=None, iteration_callback=None, sample_weight=None, **fit_params):
-        if sample_weight is not None:
-            new_X, new_y = self.fake_repetitions(X, y, sample_weight, precision=4)
-        else:
-            new_X = X
-            new_y = y
-        self.n_features = new_X.shape[1]
-        scm.fit(self, new_X, new_y, tiebreaker=None, iteration_callback=None, **fit_params)
+        self.n_features = X.shape[1]
+        scm.fit(self, X, y, tiebreaker=tiebreaker,
+                iteration_callback=iteration_callback,
+                sample_weight=sample_weight, **fit_params)
         self.feature_importances_ = np.zeros(self.n_features)
         # sum the rules importances :
         # rules_importances = estim.get_rules_importances() #activate it when pyscm will implement importance
@@ -81,21 +78,21 @@ class SCM(scm, BaseMonoviewClassifier):
         self.feature_importances_ /= np.sum(self.feature_importances_)
         return self
 
-    def fake_repetitions(self, X, y, sample_weight, precision=3):
-        sample_repetitions = (np.round(sample_weight, precision)*10**precision).astype(np.int64)
-        for ind, sample_rep in enumerate(sample_repetitions):
-            if sample_rep==0:
-                sample_repetitions[ind] = 1
-        gcd = np.gcd.reduce(sample_repetitions)
-        sample_repetitions = (sample_repetitions/gcd).astype(np.int64)
-        new_X = np.zeros((X.shape[0]+ int(np.sum(sample_repetitions)-len(sample_repetitions)), X.shape[1]))
-        new_y = np.zeros(X.shape[0]+ int(np.sum(sample_repetitions)-len(sample_repetitions)))
-        ind=0
-        for sample_index, (sample_rep, sample, label) in enumerate(zip(sample_repetitions, X, y)):
-            new_X[ind:ind+sample_rep, :] = sample
-            new_y[ind:ind+sample_rep]  =label
-            ind+=sample_rep
-        return new_X, new_y
+    # def fake_repetitions(self, X, y, sample_weight, precision=3):
+    #     sample_repetitions = (np.round(sample_weight, precision)*10**precision).astype(np.int64)
+    #     for ind, sample_rep in enumerate(sample_repetitions):
+    #         if sample_rep==0:
+    #             sample_repetitions[ind] = 1
+    #     gcd = np.gcd.reduce(sample_repetitions)
+    #     sample_repetitions = (sample_repetitions/gcd).astype(np.int64)
+    #     new_X = np.zeros((X.shape[0]+ int(np.sum(sample_repetitions)-len(sample_repetitions)), X.shape[1]))
+    #     new_y = np.zeros(X.shape[0]+ int(np.sum(sample_repetitions)-len(sample_repetitions)))
+    #     ind=0
+    #     for sample_index, (sample_rep, sample, label) in enumerate(zip(sample_repetitions, X, y)):
+    #         new_X[ind:ind+sample_rep, :] = sample
+    #         new_y[ind:ind+sample_rep]  =label
+    #         ind+=sample_rep
+    #     return new_X, new_y
 
 
 
diff --git a/summit/multiview_platform/monoview_classifiers/scmboost.py b/summit/multiview_platform/monoview_classifiers/scmboost.py
index 73db0419..2ed86500 100644
--- a/summit/multiview_platform/monoview_classifiers/scmboost.py
+++ b/summit/multiview_platform/monoview_classifiers/scmboost.py
@@ -25,69 +25,24 @@ class SCMboost(AdaBoostClassifier, BaseMonoviewClassifier):
     """
 
     def __init__(self, random_state=None, n_estimators=50,
-                 base_estimator=SCM(p=0.49, max_rules=1, model_type="disjunction"),
-                 base_estimator_config=None, **kwargs):
+                 base_estimator=SCM(p=0.49, max_rules=10, model_type="conjunction"),
+                  **kwargs):
         if "base_estimator__p" in kwargs:
             base_estimator.p = kwargs["base_estimator__p"]
+        if "base_estimator__model_type" in kwargs:
+            base_estimator.model_type = kwargs["base_estimator__model_type"]
         AdaBoostClassifier.__init__(self,
                                     random_state=random_state,
                                     n_estimators=n_estimators,
                                     base_estimator=base_estimator,
                                     algorithm="SAMME",)
-        self.param_names = ["n_estimators", "base_estimator__p"]
+        self.param_names = ["n_estimators", "base_estimator__p", "base_estimator__model_type"]
         self.classed_params = []
-        self.distribs = [CustomRandint(low=1, high=100), CustomUniform(loc=0, state=1)]
+        self.distribs = [CustomRandint(low=1, high=100), CustomUniform(loc=0, state=1), ["conjunction", "disjunction"]]
         self.weird_strings = {}
-        self.plotted_metric = metrics.zero_one_loss
-        self.plotted_metric_name = "zero_one_loss"
-        self.base_estimator_config = base_estimator_config
-        self.step_predictions = None
 
-    def fit(self, X, y, sample_weight=None):
-        begin = time.time()
-        AdaBoostClassifier.fit(self, X, y)
-        end = time.time()
-        self.train_time = end - begin
-        self.train_shape = X.shape
-        self.base_predictions = np.array(
-            [estim.predict(X) for estim in self.estimators_])
-        self.metrics = np.array([self.plotted_metric.score(pred, y) for pred in
-                                 self.staged_predict(X)])
-        return self
 
-    def predict(self, X):
-        begin = time.time()
-        pred = AdaBoostClassifier.predict(self, X)
-        end = time.time()
-        self.pred_time = end - begin
-        self.step_predictions = np.array(
-            [step_pred for step_pred in self.staged_predict(X)])
-        return pred
-
-    def get_interpretation(self, directory, base_file_name, y_test, feature_ids,
-                           multi_class=False):  # pragma: no cover
-        interpretString = ""
-        # interpretString += self.get_feature_importance(directory,
-        #                                                base_file_name,
-        #                                                feature_ids)
-        # interpretString += "\n\n Estimator error | Estimator weight\n"
-        # interpretString += "\n".join(
-        #     [str(error) + " | " + str(weight / sum(self.estimator_weights_)) for
-        #      error, weight in
-        #      zip(self.estimator_errors_, self.estimator_weights_)])
-        # step_test_metrics = np.array(
-        #     [self.plotted_metric.score(y_test, step_pred) for step_pred in
-        #      self.step_predictions])
-        # get_accuracy_graph(step_test_metrics, "Adaboost",
-        #                    os.path.join(directory,
-        #                                 base_file_name + "test_metrics.png"),
-        #                    self.plotted_metric_name, set="test")
-        # np.savetxt(os.path.join(directory, base_file_name + "test_metrics.csv"),
-        #            step_test_metrics,
-        #            delimiter=',')
-        # np.savetxt(
-        #     os.path.join(directory, base_file_name + "train_metrics.csv"),
-        #     self.metrics, delimiter=',')
-        # np.savetxt(os.path.join(directory, base_file_name + "times.csv"),
-        #            np.array([self.train_time, self.pred_time]), delimiter=',')
-        return interpretString
+    # def get_interpretation(self, directory, base_file_name, y_test, feature_ids,
+    #                        multi_class=False):  # pragma: no cover
+    #     interpretString = ""
+    #     return interpretString
diff --git a/summit/multiview_platform/utils/compression.py b/summit/multiview_platform/utils/compression.py
index 24b0ebc7..0a11b895 100644
--- a/summit/multiview_platform/utils/compression.py
+++ b/summit/multiview_platform/utils/compression.py
@@ -48,6 +48,5 @@ if __name__=="__main__":
     #     for exp in os.listdir((os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir))):
     #         print("\t", exp)
     #         explore_files(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp))
-    explore_files(
-        os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", "mage_dset", "debug_started_2022_12_13-10_15_20_th"))
+    explore_files("/home/baptiste/Documents/Gitwork/biobanq_covid_expes/results/")
     # simplify_plotly("/home//baptiste/Documents/Gitwork/summit/results/hepatitis/debug_started_2022_03_16-15_06_55__/hepatitis-mean_on_10_iter-balanced_accuracy_p.html")
-- 
GitLab