Skip to content
Snippets Groups Projects
Commit 3aced0d5 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

forgotten fits

parent 981e0c78
No related branches found
No related tags found
No related merge requests found
......@@ -49,9 +49,17 @@ class SCMPregen(scm, BaseMonoviewClassifier, PregenClassifier):
def predict(self, X):
pregen_X, _ = self.pregen_voters(X)
np.savetxt("pregen_x.csv", pregen_X, delimiter=',')
place_holder = np.genfromtxt("pregen_x.csv", delimiter=',')
os.remove("pregen_x.csv")
list_files = os.listdir(".")
if "pregen_x.csv" in list_files:
i = 0
file_name = "pregen_x" + str(i) + ".csv"
while file_name in list_files:
i += 1
else:
file_name = "pregen_x.csv"
np.savetxt(file_name, pregen_X, delimiter=',')
place_holder = np.genfromtxt(file_name, delimiter=',')
os.remove(file_name)
return self.classes_[self.model_.predict(place_holder)]
def get_params(self, deep=True):
......
......@@ -35,9 +35,17 @@ class SCMPregenTree(scm, BaseMonoviewClassifier, PregenClassifier):
def fit(self, X, y, tiebreaker=None, iteration_callback=None, **fit_params):
pregen_X, _ = self.pregen_voters(X, y, generator="Trees")
np.savetxt("pregen_x.csv", pregen_X, delimiter=',')
place_holder = np.genfromtxt("pregen_x.csv", delimiter=',')
os.remove("pregen_x.csv")
list_files = os.listdir(".")
if "pregen_x.csv" in list_files:
i = 0
file_name = "pregen_x" + str(i) + ".csv"
while file_name in list_files:
i += 1
else:
file_name = "pregen_x.csv"
np.savetxt(file_name, pregen_X, delimiter=',')
place_holder = np.genfromtxt(file_name, delimiter=',')
os.remove(file_name)
super(SCMPregenTree, self).fit(place_holder, y, tiebreaker=tiebreaker, iteration_callback=iteration_callback, **fit_params)
return self
......
......@@ -42,9 +42,17 @@ class SCMSparsity(BaseMonoviewClassifier, PregenClassifier):
def fit(self, X, y, tiebreaker=None, iteration_callback=None, **fit_params):
pregen_X, _ = self.pregen_voters(X, y)
np.savetxt("pregen_x.csv", pregen_X, delimiter=',')
place_holder = np.genfromtxt("pregen_x.csv", delimiter=',')
os.remove("pregen_x.csv")
list_files = os.listdir(".")
if "pregen_x.csv" in list_files:
i = 0
file_name = "pregen_x" + str(i) + ".csv"
while file_name in list_files:
i += 1
else:
file_name = "pregen_x.csv"
np.savetxt(file_name, pregen_X, delimiter=',')
place_holder = np.genfromtxt(file_name, delimiter=',')
os.remove(file_name)
for scm_estimator in self.scm_estimators:
beg = time.time()
scm_estimator.fit(place_holder, y, tiebreaker=None, iteration_callback=None, **fit_params)
......
......@@ -42,9 +42,17 @@ class SCMSparsityTree(BaseMonoviewClassifier, PregenClassifier):
def fit(self, X, y, tiebreaker=None, iteration_callback=None, **fit_params):
pregen_X, _ = self.pregen_voters(X, y, generator="Trees")
np.savetxt("pregen_x.csv", pregen_X, delimiter=',')
place_holder = np.genfromtxt("pregen_x.csv", delimiter=',')
os.remove("pregen_x.csv")
list_files = os.listdir(".")
if "pregen_x.csv" in list_files:
i = 0
file_name = "pregen_x" + str(i) + ".csv"
while file_name in list_files:
i += 1
else:
file_name = "pregen_x.csv"
np.savetxt(file_name, pregen_X, delimiter=',')
place_holder = np.genfromtxt(file_name, delimiter=',')
os.remove(file_name)
for scm_estimator in self.scm_estimators:
beg = time.time()
scm_estimator.fit(place_holder, y, tiebreaker=None, iteration_callback=None, **fit_params)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment