Skip to content
Snippets Groups Projects
Commit 9b7f9916 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Debugged sparsity scm from placeholder

parent 83d4f7cd
No related branches found
No related tags found
No related merge requests found
......@@ -40,6 +40,7 @@ class SCMPregen(scm, BaseMonoviewClassifier, PregenClassifier):
file_name = "pregen_x" + str(a) + ".csv"
while file_name in list_files:
a = int(np.random.randint(0, 10000))
file_name = "pregen_x" + str(a) + ".csv"
else:
file_name = "pregen_x"+str(a)+".csv"
np.savetxt(file_name, pregen_X, delimiter=',')
......@@ -57,6 +58,7 @@ class SCMPregen(scm, BaseMonoviewClassifier, PregenClassifier):
file_name = "pregen_x" + str(a) + ".csv"
while file_name in list_files:
a = int(np.random.randint(0, 10000))
file_name = "pregen_x" + str(a) + ".csv"
else:
file_name = "pregen_x"+str(a)+".csv"
np.savetxt(file_name, pregen_X, delimiter=',')
......
......@@ -42,6 +42,7 @@ class SCMPregenTree(scm, BaseMonoviewClassifier, PregenClassifier):
file_name = "pregen_x" + str(a) + ".csv"
while file_name in list_files:
a = int(np.random.randint(0, 10000))
file_name = "pregen_x" + str(a) + ".csv"
else:
file_name = "pregen_x"+str(a)+".csv"
np.savetxt(file_name, pregen_X, delimiter=',')
......@@ -59,6 +60,7 @@ class SCMPregenTree(scm, BaseMonoviewClassifier, PregenClassifier):
file_name = "pregen_x" + str(a) + ".csv"
while file_name in list_files:
a = int(np.random.randint(0, 10000))
file_name = "pregen_x" + str(a) + ".csv"
else:
file_name = "pregen_x"+str(a)+".csv"
np.savetxt(file_name, pregen_X, delimiter=',')
......
......@@ -16,7 +16,7 @@ __status__ = "Prototype" # Production, Development, Prototype
class SCMSparsity(BaseMonoviewClassifier, PregenClassifier):
def __init__(self, random_state=None, model_type="conjunction",
def __init__(self, random_state=None, model_type="disjunction",
max_rules=10, p=0.1, n_stumps=1, self_complemented=True, **kwargs):
self.scm_estimators = [scm(
random_state=random_state,
......@@ -49,6 +49,7 @@ class SCMSparsity(BaseMonoviewClassifier, PregenClassifier):
file_name = "pregen_x" + str(a) + ".csv"
while file_name in list_files:
a = int(np.random.randint(0, 10000))
file_name = "pregen_x" + str(a) + ".csv"
else:
file_name = "pregen_x"+str(a)+".csv"
np.savetxt(file_name, pregen_X, delimiter=',')
......@@ -59,7 +60,7 @@ class SCMSparsity(BaseMonoviewClassifier, PregenClassifier):
scm_estimator.fit(place_holder, y, tiebreaker=None, iteration_callback=None, **fit_params)
end = time.time()
self.times = np.array([end-beg, 0])
self.train_metrics = [zero_one_loss.score(y, scm_estimator.predict(X)) for scm_estimator in self.scm_estimators]
self.train_metrics = [zero_one_loss.score(y, scm_estimator.predict(place_holder)) for scm_estimator in self.scm_estimators]
return self.scm_estimators[-1]
def predict(self, X):
......@@ -71,6 +72,7 @@ class SCMSparsity(BaseMonoviewClassifier, PregenClassifier):
file_name = "pregen_x" + str(a) + ".csv"
while file_name in list_files:
a = int(np.random.randint(0, 10000))
file_name = "pregen_x" + str(a) + ".csv"
else:
file_name = "pregen_x"+str(a)+".csv"
np.savetxt(file_name, pregen_X, delimiter=',')
......
......@@ -49,6 +49,7 @@ class SCMSparsityTree(BaseMonoviewClassifier, PregenClassifier):
file_name = "pregen_x" + str(a) + ".csv"
while file_name in list_files:
a = int(np.random.randint(0, 10000))
file_name = "pregen_x" + str(a) + ".csv"
else:
file_name = "pregen_x"+str(a)+".csv"
np.savetxt(file_name, pregen_X, delimiter=',')
......@@ -59,7 +60,7 @@ class SCMSparsityTree(BaseMonoviewClassifier, PregenClassifier):
scm_estimator.fit(place_holder, y, tiebreaker=None, iteration_callback=None, **fit_params)
end = time.time()
self.times = np.array([end-beg, 0])
self.train_metrics = [zero_one_loss.score(y, scm_estimator.predict(X)) for scm_estimator in self.scm_estimators]
self.train_metrics = [zero_one_loss.score(y, scm_estimator.predict(place_holder)) for scm_estimator in self.scm_estimators]
return self.scm_estimators[-1]
def predict(self, X):
......@@ -71,6 +72,7 @@ class SCMSparsityTree(BaseMonoviewClassifier, PregenClassifier):
file_name = "pregen_x" + str(a) + ".csv"
while file_name in list_files:
a = int(np.random.randint(0, 10000))
file_name = "pregen_x" + str(a) + ".csv"
else:
file_name = "pregen_x"+str(a)+".csv"
np.savetxt(file_name, pregen_X, delimiter=',')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment