Skip to content
Snippets Groups Projects
Commit 9b7f9916 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Debugged sparsity scm from placeholder

parent 83d4f7cd
No related branches found
No related tags found
No related merge requests found
...@@ -40,6 +40,7 @@ class SCMPregen(scm, BaseMonoviewClassifier, PregenClassifier): ...@@ -40,6 +40,7 @@ class SCMPregen(scm, BaseMonoviewClassifier, PregenClassifier):
file_name = "pregen_x" + str(a) + ".csv" file_name = "pregen_x" + str(a) + ".csv"
while file_name in list_files: while file_name in list_files:
a = int(np.random.randint(0, 10000)) a = int(np.random.randint(0, 10000))
file_name = "pregen_x" + str(a) + ".csv"
else: else:
file_name = "pregen_x"+str(a)+".csv" file_name = "pregen_x"+str(a)+".csv"
np.savetxt(file_name, pregen_X, delimiter=',') np.savetxt(file_name, pregen_X, delimiter=',')
...@@ -57,6 +58,7 @@ class SCMPregen(scm, BaseMonoviewClassifier, PregenClassifier): ...@@ -57,6 +58,7 @@ class SCMPregen(scm, BaseMonoviewClassifier, PregenClassifier):
file_name = "pregen_x" + str(a) + ".csv" file_name = "pregen_x" + str(a) + ".csv"
while file_name in list_files: while file_name in list_files:
a = int(np.random.randint(0, 10000)) a = int(np.random.randint(0, 10000))
file_name = "pregen_x" + str(a) + ".csv"
else: else:
file_name = "pregen_x"+str(a)+".csv" file_name = "pregen_x"+str(a)+".csv"
np.savetxt(file_name, pregen_X, delimiter=',') np.savetxt(file_name, pregen_X, delimiter=',')
......
...@@ -42,6 +42,7 @@ class SCMPregenTree(scm, BaseMonoviewClassifier, PregenClassifier): ...@@ -42,6 +42,7 @@ class SCMPregenTree(scm, BaseMonoviewClassifier, PregenClassifier):
file_name = "pregen_x" + str(a) + ".csv" file_name = "pregen_x" + str(a) + ".csv"
while file_name in list_files: while file_name in list_files:
a = int(np.random.randint(0, 10000)) a = int(np.random.randint(0, 10000))
file_name = "pregen_x" + str(a) + ".csv"
else: else:
file_name = "pregen_x"+str(a)+".csv" file_name = "pregen_x"+str(a)+".csv"
np.savetxt(file_name, pregen_X, delimiter=',') np.savetxt(file_name, pregen_X, delimiter=',')
...@@ -59,6 +60,7 @@ class SCMPregenTree(scm, BaseMonoviewClassifier, PregenClassifier): ...@@ -59,6 +60,7 @@ class SCMPregenTree(scm, BaseMonoviewClassifier, PregenClassifier):
file_name = "pregen_x" + str(a) + ".csv" file_name = "pregen_x" + str(a) + ".csv"
while file_name in list_files: while file_name in list_files:
a = int(np.random.randint(0, 10000)) a = int(np.random.randint(0, 10000))
file_name = "pregen_x" + str(a) + ".csv"
else: else:
file_name = "pregen_x"+str(a)+".csv" file_name = "pregen_x"+str(a)+".csv"
np.savetxt(file_name, pregen_X, delimiter=',') np.savetxt(file_name, pregen_X, delimiter=',')
......
...@@ -16,7 +16,7 @@ __status__ = "Prototype" # Production, Development, Prototype ...@@ -16,7 +16,7 @@ __status__ = "Prototype" # Production, Development, Prototype
class SCMSparsity(BaseMonoviewClassifier, PregenClassifier): class SCMSparsity(BaseMonoviewClassifier, PregenClassifier):
def __init__(self, random_state=None, model_type="conjunction", def __init__(self, random_state=None, model_type="disjunction",
max_rules=10, p=0.1, n_stumps=1, self_complemented=True, **kwargs): max_rules=10, p=0.1, n_stumps=1, self_complemented=True, **kwargs):
self.scm_estimators = [scm( self.scm_estimators = [scm(
random_state=random_state, random_state=random_state,
...@@ -49,6 +49,7 @@ class SCMSparsity(BaseMonoviewClassifier, PregenClassifier): ...@@ -49,6 +49,7 @@ class SCMSparsity(BaseMonoviewClassifier, PregenClassifier):
file_name = "pregen_x" + str(a) + ".csv" file_name = "pregen_x" + str(a) + ".csv"
while file_name in list_files: while file_name in list_files:
a = int(np.random.randint(0, 10000)) a = int(np.random.randint(0, 10000))
file_name = "pregen_x" + str(a) + ".csv"
else: else:
file_name = "pregen_x"+str(a)+".csv" file_name = "pregen_x"+str(a)+".csv"
np.savetxt(file_name, pregen_X, delimiter=',') np.savetxt(file_name, pregen_X, delimiter=',')
...@@ -59,7 +60,7 @@ class SCMSparsity(BaseMonoviewClassifier, PregenClassifier): ...@@ -59,7 +60,7 @@ class SCMSparsity(BaseMonoviewClassifier, PregenClassifier):
scm_estimator.fit(place_holder, y, tiebreaker=None, iteration_callback=None, **fit_params) scm_estimator.fit(place_holder, y, tiebreaker=None, iteration_callback=None, **fit_params)
end = time.time() end = time.time()
self.times = np.array([end-beg, 0]) self.times = np.array([end-beg, 0])
self.train_metrics = [zero_one_loss.score(y, scm_estimator.predict(X)) for scm_estimator in self.scm_estimators] self.train_metrics = [zero_one_loss.score(y, scm_estimator.predict(place_holder)) for scm_estimator in self.scm_estimators]
return self.scm_estimators[-1] return self.scm_estimators[-1]
def predict(self, X): def predict(self, X):
...@@ -71,6 +72,7 @@ class SCMSparsity(BaseMonoviewClassifier, PregenClassifier): ...@@ -71,6 +72,7 @@ class SCMSparsity(BaseMonoviewClassifier, PregenClassifier):
file_name = "pregen_x" + str(a) + ".csv" file_name = "pregen_x" + str(a) + ".csv"
while file_name in list_files: while file_name in list_files:
a = int(np.random.randint(0, 10000)) a = int(np.random.randint(0, 10000))
file_name = "pregen_x" + str(a) + ".csv"
else: else:
file_name = "pregen_x"+str(a)+".csv" file_name = "pregen_x"+str(a)+".csv"
np.savetxt(file_name, pregen_X, delimiter=',') np.savetxt(file_name, pregen_X, delimiter=',')
......
...@@ -49,6 +49,7 @@ class SCMSparsityTree(BaseMonoviewClassifier, PregenClassifier): ...@@ -49,6 +49,7 @@ class SCMSparsityTree(BaseMonoviewClassifier, PregenClassifier):
file_name = "pregen_x" + str(a) + ".csv" file_name = "pregen_x" + str(a) + ".csv"
while file_name in list_files: while file_name in list_files:
a = int(np.random.randint(0, 10000)) a = int(np.random.randint(0, 10000))
file_name = "pregen_x" + str(a) + ".csv"
else: else:
file_name = "pregen_x"+str(a)+".csv" file_name = "pregen_x"+str(a)+".csv"
np.savetxt(file_name, pregen_X, delimiter=',') np.savetxt(file_name, pregen_X, delimiter=',')
...@@ -59,7 +60,7 @@ class SCMSparsityTree(BaseMonoviewClassifier, PregenClassifier): ...@@ -59,7 +60,7 @@ class SCMSparsityTree(BaseMonoviewClassifier, PregenClassifier):
scm_estimator.fit(place_holder, y, tiebreaker=None, iteration_callback=None, **fit_params) scm_estimator.fit(place_holder, y, tiebreaker=None, iteration_callback=None, **fit_params)
end = time.time() end = time.time()
self.times = np.array([end-beg, 0]) self.times = np.array([end-beg, 0])
self.train_metrics = [zero_one_loss.score(y, scm_estimator.predict(X)) for scm_estimator in self.scm_estimators] self.train_metrics = [zero_one_loss.score(y, scm_estimator.predict(place_holder)) for scm_estimator in self.scm_estimators]
return self.scm_estimators[-1] return self.scm_estimators[-1]
def predict(self, X): def predict(self, X):
...@@ -71,6 +72,7 @@ class SCMSparsityTree(BaseMonoviewClassifier, PregenClassifier): ...@@ -71,6 +72,7 @@ class SCMSparsityTree(BaseMonoviewClassifier, PregenClassifier):
file_name = "pregen_x" + str(a) + ".csv" file_name = "pregen_x" + str(a) + ".csv"
while file_name in list_files: while file_name in list_files:
a = int(np.random.randint(0, 10000)) a = int(np.random.randint(0, 10000))
file_name = "pregen_x" + str(a) + ".csv"
else: else:
file_name = "pregen_x"+str(a)+".csv" file_name = "pregen_x"+str(a)+".csv"
np.savetxt(file_name, pregen_X, delimiter=',') np.savetxt(file_name, pregen_X, delimiter=',')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment