Skip to content
Snippets Groups Projects
Commit d937fa6e authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Expes samba

parent 0705f277
No related branches found
No related tags found
No related merge requests found
Pipeline #11677 failed
import numpy as np
from sklearn.linear_model import Lasso as LassoSK
from sklearn.preprocessing import RobustScaler
from ..monoview.monoview_utils import BaseMonoviewClassifier
from summit.multiview_platform.utils.hyper_parameter_search import CustomUniform, CustomRandint
......@@ -19,28 +20,65 @@ class Lasso(LassoSK, BaseMonoviewClassifier):
"""
def __init__(self, random_state=None, alpha=1.0,
             max_iter=10, warm_start=False, scale=False, **kwargs):
    """
    Build the underlying scikit-learn Lasso and declare the
    hyper-parameter search space attributes.

    Parameters
    ----------
    random_state : forwarded unchanged to the scikit-learn Lasso
        constructor; also used as the single candidate value for the
        "random_state" entry of ``distribs``.
    alpha : forwarded unchanged to the scikit-learn Lasso constructor.
    max_iter : forwarded unchanged to the scikit-learn Lasso constructor.
    warm_start : forwarded unchanged to the scikit-learn Lasso constructor.
    scale : stored on the instance as ``self.scale``; ``fit`` and
        ``predict`` run the data through a RobustScaler when it is truthy.
    kwargs : accepted for call compatibility and ignored.
    """
    LassoSK.__init__(self,
                     alpha=alpha,
                     max_iter=max_iter,
                     warm_start=warm_start,
                     random_state=random_state
                     )
    self.scale = scale
    # ``param_names`` and ``distribs`` are parallel lists: entry i of
    # ``distribs`` holds the candidates for the parameter named at entry i.
    self.param_names = ["max_iter", "alpha", "scale", "random_state"]
    self.classed_params = []
    self.distribs = [CustomRandint(low=1, high=300),
                     CustomUniform(), [True, False], [random_state]]
    self.weird_strings = {}
def fit(self, X, y, check_input=True):
    """
    Fit the Lasso regressor on labels recoded from {0, 1} to {-1, 1}.

    Parameters
    ----------
    X : training samples, passed to the scaler (when ``self.scale`` is
        truthy) and then to the scikit-learn Lasso ``fit``.
    y : training labels; every entry equal to 0 is replaced by -1 in a
        copy, the caller's array is not modified.
    check_input : forwarded to the scikit-learn Lasso ``fit``.

    Returns
    -------
    self, with ``feature_importances_`` set (and ``scaler`` set when
    scaling is enabled).
    """
    # Work on a copy so the caller's label array is left untouched.
    neg_y = np.copy(y)
    neg_y[neg_y == 0] = -1
    if self.scale:
        # Kept on the instance because ``predict`` re-applies it.
        self.scaler = RobustScaler()
        X = self.scaler.fit_transform(X, neg_y)
    # Honour the caller's ``check_input`` instead of silently dropping it.
    LassoSK.fit(self, X, neg_y, check_input=check_input)
    # Normalise by the sum of absolute coefficients: dividing by the
    # signed sum lets opposite-sign coefficients cancel, which yields
    # importances that do not sum to 1 (or explode when the sum is tiny).
    abs_coef = np.abs(self.coef_)
    total = np.sum(abs_coef)
    if total != 0:
        self.feature_importances_ = abs_coef / total
    else:
        # All coefficients are zero: keep the all-zero importances.
        self.feature_importances_ = abs_coef
    return self
def predict(self, X):
    """
    Predict labels by taking the sign of the Lasso regression output.

    When ``self.scale`` is truthy, X is first transformed with the scaler
    stored on the instance. The sign of each regression output is then
    taken and every -1 is replaced by 0 before the array is returned.
    """
    features = self.scaler.transform(X) if self.scale else X
    labels = np.sign(LassoSK.predict(self, features))
    # Map the negative class back from -1 to 0.
    labels[labels == -1] = 0
    return labels
def get_interpretation(self, directory, base_file_name, y_test, feature_ids,
                       multi_class=False):  # pragma: no cover
    """
    Return the interpretation string for this classifier.

    The string is whatever ``self.get_feature_importance`` returns for
    ``directory``, ``base_file_name`` and ``feature_ids``. ``y_test`` and
    ``multi_class`` are accepted but not used here.
    """
    importance_report = self.get_feature_importance(directory,
                                                    base_file_name,
                                                    feature_ids)
    # Concatenate onto an empty string, as the accumulator form did.
    return "" + importance_report
\ No newline at end of file
import os
import time
import numpy as np
from sklearn.ensemble import AdaBoostClassifier
from sklearn.linear_model import RidgeClassifier, LogisticRegression
from lineartree import LinearTreeClassifier
from ..monoview.monoview_utils import BaseMonoviewClassifier
from summit.multiview_platform.utils.hyper_parameter_search import CustomRandint
# Author-Info
__author__ = "Baptiste Bauvin"
__status__ = "Prototype" # Production, Development, Prototype
# Name of the classifier class defined in this module.
# NOTE(review): presumably read by the platform to locate the class - confirm.
classifier_class_name = "LinearLeafsTree"
class LinearLeafsTree(LinearTreeClassifier, BaseMonoviewClassifier):
    """
    Adaptation of the ``lineartree`` package's ``LinearTreeClassifier``
    that also carries the hyper-parameter search attributes
    (``param_names``, ``distribs``, ``classed_params``, ``weird_strings``).
    """

    def __init__(self, base_estimator=None,
                 base_estimator_config=None, max_depth=None, **kwargs):
        """
        Resolve the base estimator and initialise the linear tree.

        Parameters
        ----------
        base_estimator : passed, with ``base_estimator_config``, to
            ``BaseMonoviewClassifier.get_base_estimator``; the result is
            what the linear tree receives as its base estimator.
        base_estimator_config : see above; also stored on the instance.
        max_depth : forwarded unchanged to ``LinearTreeClassifier``.
        kwargs : accepted for call compatibility and ignored.
        """
        resolved_estimator = BaseMonoviewClassifier.get_base_estimator(
            self, base_estimator, base_estimator_config)
        LinearTreeClassifier.__init__(self,
                                      base_estimator=resolved_estimator,
                                      max_depth=max_depth)
        self.base_estimator_config = base_estimator_config
        # ``param_names`` and ``distribs`` are parallel lists.
        self.param_names = ["base_estimator", "max_depth"]
        self.distribs = [[LogisticRegression(), RidgeClassifier()],
                         CustomRandint(low=1, high=20)]
        self.classed_params = ["base_estimator"]
        self.weird_strings = {"base_estimator": "class_name"}

    def get_interpretation(self, directory, base_file_name, y_test, feature_ids,
                           multi_class=False):  # pragma: no cover
        """Return an empty interpretation string; no argument is used."""
        return ""
from SamBA.samba import NeighborHoodClassifier, ExpTrainWeighting
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from SamBA.relevances import *
from SamBA.distances import *
from sklearn.preprocessing import RobustScaler
from ..monoview.monoview_utils import BaseMonoviewClassifier
from ..utils.hyper_parameter_search import CustomRandint, CustomUniform
# Author-Info
__author__ = "Baptiste Bauvin"
__status__ = "Prototype" # Production, Development, Prototype
# class Decis
# Name of the classifier class defined in this module.
# NOTE(review): presumably read by the platform to locate the class - confirm.
classifier_class_name = "SamBADTClf"
class SamBADTClf(NeighborHoodClassifier, BaseMonoviewClassifier):
    """
    SamBA ``NeighborHoodClassifier`` whose default base estimator is a
    depth-3 decision tree, extended with the hyper-parameter search
    attributes (``param_names``, ``distribs``, ``classed_params``,
    ``weird_strings``).
    """

    def __init__(self, base_estimator=DecisionTreeClassifier(max_depth=3,
                                                             splitter='best',
                                                             criterion='gini'),
                 n_estimators=2,
                 estimator_params=tuple(),
                 relevance=ExpRelevance(),
                 distance=EuclidianDist(),
                 difficulty=ExpTrainWeighting(),
                 keep_selected_features=True,
                 normalizer=RobustScaler(),
                 b=2, a=0.01,
                 pred_train=False,
                 forced_diversity=False,
                 normalize_dists=False,
                 class_weight="balanced",
                 **kwargs):
        """
        Initialise the SamBA classifier and its search space.

        Every named parameter (``base_estimator``, ``n_estimators``,
        ``estimator_params``, ``relevance``, ``distance``, ``difficulty``,
        ``keep_selected_features``, ``normalizer``, ``b``, ``a``,
        ``pred_train``, ``forced_diversity``, ``normalize_dists``,
        ``class_weight``) is forwarded unchanged, by keyword, to the
        parent constructor. ``kwargs`` is accepted and ignored.

        NOTE(review): the object-valued defaults are created once, when
        the class is defined, so instances built with the defaults share
        them - confirm the parent class does not mutate them in place.
        """
        super().__init__(
            base_estimator=base_estimator,
            n_estimators=n_estimators,
            estimator_params=estimator_params,
            relevance=relevance,
            distance=distance,
            difficulty=difficulty,
            keep_selected_features=keep_selected_features,
            normalizer=normalizer,
            forced_diversity=forced_diversity,
            b=b,
            a=a,
            pred_train=pred_train,
            normalize_dists=normalize_dists,
            class_weight=class_weight,
        )
        self.classed_params = []
        self.weird_strings = {}
        # ``param_names`` and ``distribs`` are parallel lists: entry i of
        # ``distribs`` holds the candidates for the parameter at entry i.
        self.param_names = [
            "n_estimators",
            "relevance",
            "distance",
            "difficulty",
            "b",
            "pred_train",
            "normalizer",
            "normalize_dists",
            "a",
            "class_weight",
        ]
        self.distribs = [
            CustomRandint(low=1, high=70),
            [ExpRelevance()],
            [EuclidianDist(), PolarDist(), ExpEuclidianDist(), Jaccard()],
            [ExpTrainWeighting()],
            CustomUniform(0.1, 6,),
            [True, False],
            [RobustScaler()],
            [True],
            CustomRandint(0, 10, 'e-'),
            ["balanced", None],
        ]

    def get_interpretation(self, directory, base_file_name, y_test, feature_ids,
                           multi_class=False):
        """
        Return what ``self.get_feature_importance`` produces for
        ``directory``, ``base_file_name`` and ``feature_ids``; ``y_test``
        and ``multi_class`` are not used.
        """
        return self.get_feature_importance(directory, base_file_name,
                                           feature_ids)
import os
import time
import numpy as np
from xgboost import XGBClassifier
from sklearn.tree import DecisionTreeClassifier
from .. import metrics
from ..monoview.monoview_utils import BaseMonoviewClassifier, get_accuracy_graph
from summit.multiview_platform.utils.hyper_parameter_search import CustomRandint, CustomUniform
# Author-Info
__author__ = "Baptiste Bauvin"
__status__ = "Prototype" # Production, Development, Prototype
# Name of the classifier class defined in this module.
# NOTE(review): presumably read by the platform to locate the class - confirm.
classifier_class_name = "XGB"
class XGB(XGBClassifier, BaseMonoviewClassifier):
    """
    Adaptation of xgboost's ``XGBClassifier`` that also carries the
    hyper-parameter search attributes (``param_names``, ``distribs``,
    ``classed_params``, ``weird_strings``).
    """

    def __init__(self, random_state=None, learning_rate=0.3, max_depth=1,
                 n_estimators=100, objective='binary:logistic',
                 **kwargs):
        """
        Initialise the underlying XGBClassifier and the search space.

        Parameters
        ----------
        random_state, learning_rate, max_depth, n_estimators, objective :
            forwarded unchanged, by keyword, to ``XGBClassifier.__init__``.
            ``max_depth`` defaults to the integer 1: xgboost documents
            this parameter as an integer, so the former float default
            ``1.0`` was the wrong type.
        kwargs : accepted for call compatibility and ignored.
        """
        XGBClassifier.__init__(self, learning_rate=learning_rate,
                               max_depth=max_depth,
                               n_estimators=n_estimators, objective=objective,
                               random_state=random_state)
        # ``param_names`` and ``distribs`` are parallel lists.
        self.param_names = ["n_estimators", "learning_rate", "max_depth",
                            "objective"]
        self.classed_params = []
        self.distribs = [CustomRandint(low=10, high=500),
                         CustomUniform(),
                         CustomRandint(low=1, high=10),
                         ['binary:logistic', 'binary:hinge', ], ]
        self.weird_strings = {}

    def get_interpretation(self, directory, base_file_name, y_test, feature_ids,
                           multi_class=False):
        """
        Return the interpretation string for this classifier.

        With ``multi_class`` truthy the result is an empty string.
        Otherwise it is what ``self.get_feature_importance`` returns for
        ``directory``, ``base_file_name`` and ``feature_ids``. ``y_test``
        is not used.
        """
        if multi_class:
            return ""
        return "" + self.get_feature_importance(directory,
                                                base_file_name,
                                                feature_ids)
......@@ -43,14 +43,14 @@ def remove_compressed(exp_path):
if __name__ == "__main__":
    # Script entry point: run ``explore_files`` on one hard-coded results
    # directory. The commented-out calls are alternative targets kept
    # for manual switching between experiments.
    # for dir in os.listdir("/home/baptiste/Documents/Gitwork/summit/results/"):
    #     print(dir)
    #     for exp in os.listdir((os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir))):
    #         print("\t", exp)
    #         if os.path.isdir(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp)):
    #             explore_files(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp))
    # # explore_files("/home/baptiste/Documents/Gitwork/biobanq_covid_expes/results/")
    # explore_files("/home/baptiste/Documents/Gitwork/summit/results/tnbc_mazid/debug_started_2023_03_24-11_27_46_thesis")
    explore_files("/home/baptiste/Documents/Gitwork/summit/results/clinical/debug_started_2023_04_05-08_23_00_bal_acc")
    # explore_files(
    #     "/home/baptiste/Documents/Gitwork/summit/results/lives_thesis_EMF/debug_started_2023_03_24-10_02_21_thesis")
    # # simplify_plotly("/home//baptiste/Documents/Gitwork/summit/results/hepatitis/debug_started_2022_03_16-15_06_55__/hepatitis-mean_on_10_iter-balanced_accuracy_p.html")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment