From 4e8bb05a1552b91473a9633f225672926fbc8760 Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Mon, 22 May 2023 09:58:22 -0400
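Subject: [PATCH] Update for camera-ready

Add a locally weighted logistic regression monoview classifier (LWLR),
expose `gamma` on SVMRBF, and change the hyper-parameter search ranges of
GradientBoosting, XGB and SCMboost. Also make the result-compression
script walk every experiment directory under results/.
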
---
 .../monoview_classifiers/gradient_boosting.py |  2 +-
 .../locally_weighted_linear_regression.py     | 54 +++++++++++++++++++
 .../monoview_classifiers/lwlr.py              |  9 ++--
 .../monoview_classifiers/scmboost.py          |  2 +-
 .../monoview_classifiers/svm_rbf.py           |  7 +--
 .../monoview_classifiers/xgboost.py           |  2 +-
 .../multiview_platform/utils/compression.py   | 17 ++++--
 7 files changed, 77 insertions(+), 16 deletions(-)

diff --git a/summit/multiview_platform/monoview_classifiers/gradient_boosting.py b/summit/multiview_platform/monoview_classifiers/gradient_boosting.py
index 57288a25..281a35e9 100644
--- a/summit/multiview_platform/monoview_classifiers/gradient_boosting.py
+++ b/summit/multiview_platform/monoview_classifiers/gradient_boosting.py
@@ -44,7 +44,7 @@ class GradientBoosting(GradientBoostingClassifier, BaseMonoviewClassifier):
                                             )
         self.param_names = ["n_estimators", "max_depth", "loss", "learning_rate"]
         self.classed_params = []
-        self.distribs = [CustomRandint(low=50, high=500),
+        self.distribs = [CustomRandint(low=1, high=300),
                          CustomRandint(low=1, high=10),
                          ['log_loss', 'deviance', 'exponential'],
                          CustomUniform(loc=0, state=1)]
diff --git a/summit/multiview_platform/monoview_classifiers/locally_weighted_linear_regression.py b/summit/multiview_platform/monoview_classifiers/locally_weighted_linear_regression.py
index e69de29b..23265dc4 100644
--- a/summit/multiview_platform/monoview_classifiers/locally_weighted_linear_regression.py
+++ b/summit/multiview_platform/monoview_classifiers/locally_weighted_linear_regression.py
@@ -0,0 +1,92 @@
+import numpy as np
+
+from ..monoview.monoview_utils import BaseMonoviewClassifier
+from summit.multiview_platform.utils.hyper_parameter_search import CustomUniform
+
+classifier_class_name = "LWLR"
+
+
+class LWLR(BaseMonoviewClassifier):
+    """Locally weighted logistic regression for binary classification.
+
+    ``fit`` only stores the training set. ``predict`` then fits one
+    L2-regularised logistic model per query point, weighting every training
+    sample by a Gaussian kernel centred on that query point.
+
+    NOTE(review): the update rules assume labels in {0, 1} -- confirm callers.
+    """
+
+    def __init__(self, tau=0.05, reg=0.0001, threshold=1e-6, random_state=42):
+        # Bandwidth of the Gaussian kernel used in ``weights``.
+        self.tau = tau
+        # L2 penalty applied to theta in ``train``.
+        self.reg = reg
+        # Newton iterations stop once the gradient norm is <= threshold.
+        self.threshold = threshold
+        # Stored only; nothing in this class draws random numbers.
+        self.random_state = random_state
+        # Tunable hyper-parameters and, presumably in the same order, the
+        # distributions the hyper-parameter search samples them from.
+        self.param_names = ["tau", 'reg', "threshold"]
+        self.distribs = [CustomUniform(loc=1e-2, state=1),
+                         CustomUniform(loc=1e-6, state=1e-2),
+                         CustomUniform(loc=1e-8, state=1e-4)]
+        self.weird_strings = {}
+        self.classed_params = []
+
+    def weights(self, x_train, x):
+        """Return the Gaussian kernel weight of each row of x_train w.r.t. x.
+
+        The weight of row x_i is exp(-||x_i - x||^2 / (2 * tau^2)).
+        """
+        sq_diff = (x_train - x) ** 2
+        norm_sq = sq_diff.sum(axis=1)
+        return np.ravel(np.exp(- norm_sq / (2 * self.tau ** 2)))
+
+    def logistic(self, x_train):
+        """Return sigmoid(x_train . theta) flattened to a 1-D array."""
+        return np.ravel(1 / (1 + np.exp(-x_train.dot(self.theta))))
+
+    def fit(self, X, y, **fit_params):
+        """Store the training set; the models are fitted lazily in predict.
+
+        Returns self so that calls can be chained.
+        """
+        self.X = X
+        self.y = y
+        return self
+
+    def train(self, x):
+        """Fit theta around the query point x with Newton's method.
+
+        Iterates on the kernel-weighted log-likelihood with an L2 penalty of
+        strength ``reg`` until the gradient norm is <= ``threshold``.
+        """
+        self.w = self.weights(self.X, x)
+        self.theta = np.zeros(self.X.shape[1])
+        gradient = np.ones(self.X.shape[1]) * np.inf
+        while np.linalg.norm(gradient) > self.threshold:
+            # compute gradient
+            h = self.logistic(self.X)
+            gradient = self.X.T.dot(
+                self.w * (np.ravel(self.y) - h)) - self.reg * self.theta
+            # Hessian X^T diag(d) X - reg * I. Scaling the columns of X^T by d
+            # avoids building the n_samples x n_samples diagonal matrix.
+            d = -(self.w * h * (1 - h))
+            H = (self.X.T * d).dot(self.X) - self.reg * np.identity(
+                self.X.shape[1])
+            # Newton step; solve() avoids forming the explicit inverse of H.
+            self.theta = self.theta - np.linalg.solve(H, gradient)
+
+    def predict(self, X):
+        """Return a 0/1 prediction for every row of X.
+
+        A dedicated local model is trained for each row (see ``train``).
+        """
+        preds = []
+        for x in X:
+            self.train(x)
+            # theta was fitted around x, so x itself is what must be scored
+            # (scoring the whole batch and keeping entry 0 would label X[0]).
+            preds.append(np.array(self.logistic(x) > 0.5).astype(int)[0])
+        return np.array(preds)
diff --git a/summit/multiview_platform/monoview_classifiers/lwlr.py b/summit/multiview_platform/monoview_classifiers/lwlr.py
index 48db51d3..2387ed17 100644
--- a/summit/multiview_platform/monoview_classifiers/lwlr.py
+++ b/summit/multiview_platform/monoview_classifiers/lwlr.py
@@ -1,5 +1,4 @@
-from summit.multiview_platform.monoview_classifiers.additions.SVCClassifier import \
-    SVCClassifier
+from learners.algorithms.lwlr import LWLRLearner
 
 from ..monoview.monoview_utils import BaseMonoviewClassifier
 from summit.multiview_platform.utils.hyper_parameter_search import CustomUniform
@@ -11,15 +10,15 @@ __status__ = "Prototype"  # Production, Development, Prototype
 classifier_class_name = "SVMRBF"
 
 
-class SVMRBF(SVCClassifier, BaseMonoviewClassifier):
+class LWLRClassifier(LWLRLearner, BaseMonoviewClassifier):
     """
     This class is an adaptation of scikit-learn's `SVC <https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html>`_
 
     Here, it is the RBF kernel version
     """
 
-    def __init__(self, random_state=None, gamma="auto", C=1.0, **kwargs):
-        SVCClassifier.__init__(self,
+    def __init__(self, random_state=None, sigma="auto", nn=1.0, **kwargs):
+        LWLRLearner.__init__(self,
                                C=C,
                                kernel='rbf',
                                gamma=gamma,
diff --git a/summit/multiview_platform/monoview_classifiers/scmboost.py b/summit/multiview_platform/monoview_classifiers/scmboost.py
index 89f1e4b7..070c1207 100644
--- a/summit/multiview_platform/monoview_classifiers/scmboost.py
+++ b/summit/multiview_platform/monoview_classifiers/scmboost.py
@@ -40,7 +40,7 @@ class SCMboost(AdaBoostClassifier, BaseMonoviewClassifier):
                                     algorithm="SAMME",)
         self.param_names = ["n_estimators", "base_estimator__p", "base_estimator__model_type", "base_estimator__max_rules"]
         self.classed_params = []
-        self.distribs = [CustomRandint(low=1, high=100), CustomUniform(loc=0, state=1), ["conjunction", "disjunction"], CustomRandint(low=1, high=20)]
+        self.distribs = [CustomRandint(low=1, high=100), CustomUniform(loc=0, state=1), ["conjunction", "disjunction"], CustomRandint(low=1, high=5)]
         self.weird_strings = {}
 
 
diff --git a/summit/multiview_platform/monoview_classifiers/svm_rbf.py b/summit/multiview_platform/monoview_classifiers/svm_rbf.py
index 504ef095..48db51d3 100644
--- a/summit/multiview_platform/monoview_classifiers/svm_rbf.py
+++ b/summit/multiview_platform/monoview_classifiers/svm_rbf.py
@@ -18,11 +18,12 @@ class SVMRBF(SVCClassifier, BaseMonoviewClassifier):
     Here, it is the RBF kernel version
     """
 
-    def __init__(self, random_state=None, C=1.0, **kwargs):
+    def __init__(self, random_state=None, gamma="auto", C=1.0, **kwargs):
         SVCClassifier.__init__(self,
                                C=C,
                                kernel='rbf',
+                               gamma=gamma,
                                random_state=random_state
                                )
-        self.param_names = ["C", "random_state"]
-        self.distribs = [CustomUniform(loc=0, state=1), [random_state]]
+        self.param_names = ["C", 'gamma', "random_state"]
+        self.distribs = [CustomUniform(loc=1e-3, state=1e3), CustomUniform(loc=1e-1, state=1e1),  [random_state]]
diff --git a/summit/multiview_platform/monoview_classifiers/xgboost.py b/summit/multiview_platform/monoview_classifiers/xgboost.py
index 2b44f211..280b1150 100644
--- a/summit/multiview_platform/monoview_classifiers/xgboost.py
+++ b/summit/multiview_platform/monoview_classifiers/xgboost.py
@@ -32,7 +32,7 @@ class XGB(XGBClassifier, BaseMonoviewClassifier):
                                random_state=random_state)
         self.param_names = ["n_estimators", "learning_rate", "max_depth", "objective"]
         self.classed_params = []
-        self.distribs = [CustomRandint(low=10, high=500),
+        self.distribs = [CustomRandint(low=1, high=300),
                          CustomUniform(),
                          CustomRandint(low=1, high=10),
                          ['binary:logistic', 'binary:hinge', ],]
diff --git a/summit/multiview_platform/utils/compression.py b/summit/multiview_platform/utils/compression.py
index 6fb7cb49..a2298f27 100644
--- a/summit/multiview_platform/utils/compression.py
+++ b/summit/multiview_platform/utils/compression.py
@@ -43,14 +43,21 @@ def remove_compressed(exp_path):
 
 
 if __name__=="__main__":
-    # for dir in os.listdir("/home/baptiste/Documents/Gitwork/summit/results/"):
+    for dir in os.listdir("/home/baptiste/Documents/Gitwork/summit/results/"):
+        if os.path.isdir(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir)):
+            for exp in os.listdir((os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir))):
+                print("\t", exp)
+                if os.path.isdir(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp)):
+                    explore_files(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp))
+    plif = dict()
+    # for dir in os.listdir("/home/baptiste/Documents/Clouded/short_projects/SCMBoost/results"):
     #     print(dir)
-    #     for exp in os.listdir((os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir))):
+    #     for exp in os.listdir((os.path.join("/home/baptiste/Documents/Clouded/short_projects/SCMBoost/results", dir))):
     #         print("\t", exp)
-    #         if os.path.isdir(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp)):
-    #             explore_files(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp))
+    #         if os.path.isdir(os.path.join("/home/baptiste/Documents/Clouded/short_projects/SCMBoost/results", dir, exp)):
+    #             explore_files(os.path.join("/home/baptiste/Documents/Clouded/short_projects/SCMBoost/results", dir, exp))
     # # explore_files("/home/baptiste/Documents/Gitwork/biobanq_covid_expes/results/")
-    explore_files("/home/baptiste/Documents/Gitwork/summit/results/clinical/debug_started_2023_04_05-08_23_00_bal_acc")
+    # explore_files("/home/baptiste/Documents/Gitwork/summit/results/clinical/debug_started_2023_04_05-08_23_00_bal_acc")
     # explore_files(
     #     "/home/baptiste/Documents/Gitwork/summit/results/lives_thesis_EMF/debug_started_2023_03_24-10_02_21_thesis")
     # # simplify_plotly("/home//baptiste/Documents/Gitwork/summit/results/hepatitis/debug_started_2022_03_16-15_06_55__/hepatitis-mean_on_10_iter-balanced_accuracy_p.html")
-- 
GitLab