Skip to content
Snippets Groups Projects
Commit 279c4c8e authored by Baptiste Bauvin
Browse files

Pyscm boostable

parent 0468df2e
No related branches found
No related tags found
No related merge requests found
Pipeline #11514 failed
......@@ -64,13 +64,10 @@ class SCM(scm, BaseMonoviewClassifier):
self.weird_strings = {}
def fit(self, X, y, tiebreaker=None, iteration_callback=None, sample_weight=None, **fit_params):
if sample_weight is not None:
new_X, new_y = self.fake_repetitions(X, y, sample_weight, precision=4)
else:
new_X = X
new_y = y
self.n_features = new_X.shape[1]
scm.fit(self, new_X, new_y, tiebreaker=None, iteration_callback=None, **fit_params)
self.n_features = X.shape[1]
scm.fit(self, X, y, tiebreaker=tiebreaker,
iteration_callback=iteration_callback,
sample_weight=sample_weight, **fit_params)
self.feature_importances_ = np.zeros(self.n_features)
# sum the rules importances :
# rules_importances = estim.get_rules_importances() #activate it when pyscm will implement importance
......@@ -81,21 +78,21 @@ class SCM(scm, BaseMonoviewClassifier):
self.feature_importances_ /= np.sum(self.feature_importances_)
return self
def fake_repetitions(self, X, y, sample_weight, precision=3):
    """Emulate sample weights by repeating samples an integer number of times.

    Each weight is rounded to ``precision`` decimals and scaled by
    ``10 ** precision`` to obtain an integer repetition count. Counts that
    come out as zero are raised to one so that no sample disappears, and all
    counts are divided by their greatest common divisor to keep the
    duplicated data set as small as possible.

    Parameters
    ----------
    X : array-like, first axis indexes the samples
    y : array-like, one label per sample
    sample_weight : array-like, one weight per sample
    precision : int, number of decimals of the weights that are kept

    Returns
    -------
    (new_X, new_y) : the repeated samples and labels, in the original order.
        Both keep the dtype of their input (labels are no longer cast to
        float).
    """
    X = np.asarray(X)
    y = np.asarray(y)
    weights = np.asarray(sample_weight, dtype=float)

    # Nothing to repeat; also avoids dividing by gcd == 0 on empty input.
    if weights.size == 0:
        return X.copy(), y.copy()

    # rint before the integer cast: a bare astype() truncates, so a product
    # such as 0.29 * 100 == 28.999999999999996 would become 28 instead of 29.
    scaled = np.round(weights, precision) * 10 ** precision
    repetitions = np.rint(scaled).astype(np.int64)

    # Keep every sample at least once, even when its weight rounds to zero.
    repetitions[repetitions == 0] = 1

    # Exact integer division: every count is a multiple of the gcd.
    repetitions //= np.gcd.reduce(repetitions)

    return np.repeat(X, repetitions, axis=0), np.repeat(y, repetitions)
# def fake_repetitions(self, X, y, sample_weight, precision=3):
# sample_repetitions = (np.round(sample_weight, precision)*10**precision).astype(np.int64)
# for ind, sample_rep in enumerate(sample_repetitions):
# if sample_rep==0:
# sample_repetitions[ind] = 1
# gcd = np.gcd.reduce(sample_repetitions)
# sample_repetitions = (sample_repetitions/gcd).astype(np.int64)
# new_X = np.zeros((X.shape[0]+ int(np.sum(sample_repetitions)-len(sample_repetitions)), X.shape[1]))
# new_y = np.zeros(X.shape[0]+ int(np.sum(sample_repetitions)-len(sample_repetitions)))
# ind=0
# for sample_index, (sample_rep, sample, label) in enumerate(zip(sample_repetitions, X, y)):
# new_X[ind:ind+sample_rep, :] = sample
# new_y[ind:ind+sample_rep] =label
# ind+=sample_rep
# return new_X, new_y
......
......@@ -25,69 +25,24 @@ class SCMboost(AdaBoostClassifier, BaseMonoviewClassifier):
"""
# NOTE(review): this span comes from a commit diff whose +/- markers were
# lost, so superseded and replacement lines appear one after the other (two
# `base_estimator=` defaults, two `self.param_names` and two `self.distribs`
# assignments). Presumably the second line of each pair is the post-commit
# version -- TODO confirm against the repository before editing.
def __init__(self, random_state=None, n_estimators=50,
base_estimator=SCM(p=0.49, max_rules=1, model_type="disjunction"),
base_estimator_config=None, **kwargs):
base_estimator=SCM(p=0.49, max_rules=10, model_type="conjunction"),
**kwargs):
# NOTE(review): the default SCM(...) object is created once, when the
# signature is evaluated, and the two assignments below mutate whichever
# `base_estimator` object was received -- including that shared default.
if "base_estimator__p" in kwargs:
base_estimator.p = kwargs["base_estimator__p"]
if "base_estimator__model_type" in kwargs:
base_estimator.model_type = kwargs["base_estimator__model_type"]
# The boosting algorithm is fixed to "SAMME"; the other arguments are
# forwarded unchanged to AdaBoostClassifier.
AdaBoostClassifier.__init__(self,
random_state=random_state,
n_estimators=n_estimators,
base_estimator=base_estimator,
algorithm="SAMME",)
# presumably param_names and distribs are parallel lists (one sampling
# distribution per hyper-parameter name) -- verify against the caller
self.param_names = ["n_estimators", "base_estimator__p"]
self.param_names = ["n_estimators", "base_estimator__p", "base_estimator__model_type"]
self.classed_params = []
self.distribs = [CustomRandint(low=1, high=100), CustomUniform(loc=0, state=1)]
self.distribs = [CustomRandint(low=1, high=100), CustomUniform(loc=0, state=1), ["conjunction", "disjunction"]]
self.weird_strings = {}
# Metric object and display name stored for later per-step evaluation.
self.plotted_metric = metrics.zero_one_loss
self.plotted_metric_name = "zero_one_loss"
self.base_estimator_config = base_estimator_config
# Starts as None; nothing in this method fills it.
self.step_predictions = None
def fit(self, X, y, sample_weight=None):
    """Fit the boosted ensemble and record training diagnostics.

    Parameters
    ----------
    X : training samples; must expose ``.shape``
    y : training labels
    sample_weight : optional per-sample weights, forwarded to
        ``AdaBoostClassifier.fit`` (they were previously accepted but
        silently dropped)

    Returns
    -------
    self

    Side effects: sets ``train_time`` (seconds spent in the underlying fit),
    ``train_shape``, ``base_predictions`` (each fitted estimator's
    predictions on ``X``) and ``metrics`` (``plotted_metric`` score of every
    staged prediction on ``X``).
    """
    begin = time.time()
    # Forward the weights instead of ignoring the caller's argument.
    AdaBoostClassifier.fit(self, X, y, sample_weight=sample_weight)
    end = time.time()
    self.train_time = end - begin
    self.train_shape = X.shape
    self.base_predictions = np.array(
        [estim.predict(X) for estim in self.estimators_])
    self.metrics = np.array([self.plotted_metric.score(pred, y) for pred in
                             self.staged_predict(X)])
    return self
def predict(self, X):
    """Predict labels for ``X`` and keep timing and per-step predictions.

    Stores the duration of the underlying ``AdaBoostClassifier.predict``
    call in ``self.pred_time`` and the array of staged predictions in
    ``self.step_predictions``, then returns the predicted labels.
    """
    started_at = time.time()
    labels = AdaBoostClassifier.predict(self, X)
    finished_at = time.time()
    self.pred_time = finished_at - started_at
    # One row per boosting step, as produced by staged_predict.
    staged = list(self.staged_predict(X))
    self.step_predictions = np.array(staged)
    return labels
def get_interpretation(self, directory, base_file_name, y_test, feature_ids,
                       multi_class=False):  # pragma: no cover
    """Return the textual interpretation of the model.

    No report is produced at the moment: every argument is ignored and an
    empty string is returned.
    """
    # The former reporting steps are disabled: feature importances, the
    # estimator error / weight table, the test-metric graph, and the CSV
    # dumps of test metrics, train metrics and timings.
    return ""
# def get_interpretation(self, directory, base_file_name, y_test, feature_ids,
# multi_class=False): # pragma: no cover
# interpretString = ""
# return interpretString
......@@ -48,6 +48,5 @@ if __name__=="__main__":
# for exp in os.listdir((os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir))):
# print("\t", exp)
# explore_files(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp))
explore_files(
os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", "mage_dset", "debug_started_2022_12_13-10_15_20_th"))
explore_files("/home/baptiste/Documents/Gitwork/biobanq_covid_expes/results/")
# simplify_plotly("/home//baptiste/Documents/Gitwork/summit/results/hepatitis/debug_started_2022_03_16-15_06_55__/hepatitis-mean_on_10_iter-balanced_accuracy_p.html")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment