Skip to content
Snippets Groups Projects
Commit 981e0c78 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Avoid multiple temp file to crush themselves

parent 52c84267
No related branches found
No related tags found
No related merge requests found
......@@ -33,9 +33,17 @@ class SCMPregen(scm, BaseMonoviewClassifier, PregenClassifier):
def fit(self, X, y, tiebreaker=None, iteration_callback=None, **fit_params):
pregen_X, _ = self.pregen_voters(X, y)
np.savetxt("pregen_x.csv", pregen_X, delimiter=',')
place_holder = np.genfromtxt("pregen_x.csv", delimiter=',')
os.remove("pregen_x.csv")
list_files = os.listdir(".")
if "pregen_x.csv" in list_files:
i = 0
file_name = "pregen_x" + str(i) + ".csv"
while file_name in list_files:
i += 1
else:
file_name="pregen_x.csv"
np.savetxt(file_name, pregen_X, delimiter=',')
place_holder = np.genfromtxt(file_name, delimiter=',')
os.remove(file_name)
super(SCMPregen, self).fit(place_holder, y, tiebreaker=tiebreaker, iteration_callback=iteration_callback, **fit_params)
return self
......
......@@ -43,9 +43,17 @@ class SCMPregenTree(scm, BaseMonoviewClassifier, PregenClassifier):
def predict(self, X):
pregen_X, _ = self.pregen_voters(X, generator="Trees")
np.savetxt("pregen_x.csv", pregen_X, delimiter=',')
place_holder = np.genfromtxt("pregen_x.csv", delimiter=',')
os.remove("pregen_x.csv")
list_files = os.listdir(".")
if "pregen_x.csv" in list_files:
i = 0
file_name = "pregen_x" + str(i) + ".csv"
while file_name in list_files:
i += 1
else:
file_name="pregen_x.csv"
np.savetxt(file_name, pregen_X, delimiter=',')
place_holder = np.genfromtxt(file_name, delimiter=',')
os.remove(file_name)
return self.classes_[self.model_.predict(place_holder)]
def get_params(self, deep=True):
......
......@@ -55,9 +55,17 @@ class SCMSparsity(BaseMonoviewClassifier, PregenClassifier):
def predict(self, X):
pregen_X, _ = self.pregen_voters(X,)
np.savetxt("pregen_x.csv", pregen_X, delimiter=',')
place_holder = np.genfromtxt("pregen_x.csv", delimiter=',')
os.remove("pregen_x.csv")
list_files = os.listdir(".")
if "pregen_x.csv" in list_files:
i = 0
file_name = "pregen_x" + str(i) + ".csv"
while file_name in list_files:
i += 1
else:
file_name="pregen_x.csv"
np.savetxt(file_name, pregen_X, delimiter=',')
place_holder = np.genfromtxt(file_name, delimiter=',')
os.remove(file_name)
self.preds = [scm_estimator.predict(place_holder) for scm_estimator in self.scm_estimators]
return self.preds[-1]
......
......@@ -55,9 +55,17 @@ class SCMSparsityTree(BaseMonoviewClassifier, PregenClassifier):
def predict(self, X):
pregen_X, _ = self.pregen_voters(X, generator="Trees")
np.savetxt("pregen_x.csv", pregen_X, delimiter=',')
place_holder = np.genfromtxt("pregen_x.csv", delimiter=',')
os.remove("pregen_x.csv")
list_files = os.listdir(".")
if "pregen_x.csv" in list_files:
i=0
file_name="pregen_x"+str(i)+".csv"
while file_name in list_files:
i+=1
else:
file_name="pregen_x.csv"
np.savetxt(file_name, pregen_X, delimiter=',')
place_holder = np.genfromtxt(file_name, delimiter=',')
os.remove(file_name)
self.preds = [scm_estimator.predict(place_holder) for scm_estimator in self.scm_estimators]
return self.preds[-1]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment